diff options
author | Andras Becsi <andras.becsi@digia.com> | 2013-12-11 21:33:03 +0100 |
---|---|---|
committer | Andras Becsi <andras.becsi@digia.com> | 2013-12-13 12:34:07 +0100 |
commit | f2a33ff9cbc6d19943f1c7fbddd1f23d23975577 (patch) | |
tree | 0586a32aa390ade8557dfd6b4897f43a07449578 /chromium/net | |
parent | 5362912cdb5eea702b68ebe23702468d17c3017a (diff) | |
download | qtwebengine-chromium-f2a33ff9cbc6d19943f1c7fbddd1f23d23975577.tar.gz |
Update Chromium to branch 1650 (31.0.1650.63)
Change-Id: I57d8c832eaec1eb2364e0a8e7352a6dd354db99f
Reviewed-by: Jocelyn Turcotte <jocelyn.turcotte@digia.com>
Diffstat (limited to 'chromium/net')
586 files changed, 28857 insertions, 14292 deletions
diff --git a/chromium/net/OWNERS b/chromium/net/OWNERS index 18f14625033..87e11faef93 100644 --- a/chromium/net/OWNERS +++ b/chromium/net/OWNERS @@ -1,4 +1,5 @@ agl@chromium.org +akalin@chromium.org asanka@chromium.org cbentzel@chromium.org eroman@chromium.org diff --git a/chromium/net/android/java/CertVerifyResultAndroid.template b/chromium/net/android/java/CertVerifyResultAndroid.template deleted file mode 100644 index b19e937fcb9..00000000000 --- a/chromium/net/android/java/CertVerifyResultAndroid.template +++ /dev/null @@ -1,10 +0,0 @@ -// Copyright (c) 2013 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -package org.chromium.net; - -public class CertVerifyResultAndroid { -#define CERT_VERIFY_RESULT_ANDROID(name, value) public static final int VERIFY_##name = value; -#include "net/android/cert_verify_result_android_list.h" -} diff --git a/chromium/net/android/java/CertificateMimeType.template b/chromium/net/android/java/CertificateMimeType.template deleted file mode 100644 index 5a21171e88b..00000000000 --- a/chromium/net/android/java/CertificateMimeType.template +++ /dev/null @@ -1,11 +0,0 @@ -// Copyright (c) 2012 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -package org.chromium.net; - -public class CertificateMimeType { -#define CERTIFICATE_MIME_TYPE(name, value) public static final int name = value; -#include "net/base/mime_util_certificate_type_list.h" -#undef CERTIFICATE_MIME_TYPE -} diff --git a/chromium/net/android/java/NetError.template b/chromium/net/android/java/NetError.template deleted file mode 100644 index f6c16617b3b..00000000000 --- a/chromium/net/android/java/NetError.template +++ /dev/null @@ -1,10 +0,0 @@ -// Copyright (c) 2012 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -package org.chromium.net; - -public class NetError { -#define NET_ERROR(name, value) public static final int ERR_##name = value; -#include "net/base/net_error_list.h" -} diff --git a/chromium/net/android/java/PrivateKeyType.template b/chromium/net/android/java/PrivateKeyType.template deleted file mode 100644 index aa7f76f8881..00000000000 --- a/chromium/net/android/java/PrivateKeyType.template +++ /dev/null @@ -1,10 +0,0 @@ -// Copyright (c) 2013 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -package org.chromium.net; - -public class PrivateKeyType { -#define DEFINE_PRIVATE_KEY_TYPE(name,value) public static final int name = value; -#include "net/android/private_key_type_list.h" -} diff --git a/chromium/net/android/java/src/org/chromium/net/AndroidKeyStore.java b/chromium/net/android/java/src/org/chromium/net/AndroidKeyStore.java deleted file mode 100644 index de5d8f2d5f0..00000000000 --- a/chromium/net/android/java/src/org/chromium/net/AndroidKeyStore.java +++ /dev/null @@ -1,309 +0,0 @@ -// Copyright (c) 2013 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -package org.chromium.net; - -import android.util.Log; - -import java.lang.reflect.InvocationTargetException; -import java.lang.reflect.Method; -import java.math.BigInteger; -import java.security.interfaces.DSAKey; -import java.security.interfaces.DSAPrivateKey; -import java.security.interfaces.DSAParams; -import java.security.interfaces.ECKey; -import java.security.interfaces.ECPrivateKey; -import java.security.interfaces.RSAKey; -import java.security.interfaces.RSAPrivateKey; -import java.security.NoSuchAlgorithmException; -import java.security.PrivateKey; -import java.security.Signature; -import java.security.spec.ECParameterSpec; - -import org.chromium.base.CalledByNative; -import org.chromium.base.JNINamespace; -import org.chromium.net.PrivateKeyType;; - -@JNINamespace("net::android") -public class AndroidKeyStore { - - private static final String TAG = "AndroidKeyStore"; - - //////////////////////////////////////////////////////////////////// - // - // Message signing support. - - /** - * Returns the public modulus of a given RSA private key as a byte - * buffer. - * This can be used by native code to convert the modulus into - * an OpenSSL BIGNUM object. Required to craft a custom native RSA - * object where RSA_size() works as expected. - * - * @param key A PrivateKey instance, must implement RSAKey. - * @return A byte buffer corresponding to the modulus. This is - * big-endian representation of a BigInteger. - */ - @CalledByNative - public static byte[] getRSAKeyModulus(PrivateKey key) { - if (key instanceof RSAKey) { - return ((RSAKey) key).getModulus().toByteArray(); - } else { - Log.w(TAG, "Not a RSAKey instance!"); - return null; - } - } - - /** - * Returns the 'Q' parameter of a given DSA private key as a byte - * buffer. - * This can be used by native code to convert it into an OpenSSL BIGNUM - * object where DSA_size() works as expected. - * - * @param key A PrivateKey instance. Must implement DSAKey. - * @return A byte buffer corresponding to the Q parameter. This is - * a big-endian representation of a BigInteger. - */ - @CalledByNative - public static byte[] getDSAKeyParamQ(PrivateKey key) { - if (key instanceof DSAKey) { - DSAParams params = ((DSAKey) key).getParams(); - return params.getQ().toByteArray(); - } else { - Log.w(TAG, "Not a DSAKey instance!"); - return null; - } - } - - /** - * Returns the 'order' parameter of a given ECDSA private key as a - * a byte buffer. - * @param key A PrivateKey instance. Must implement ECKey. - * @return A byte buffer corresponding to the 'order' parameter. - * This is a big-endian representation of a BigInteger. - */ - @CalledByNative - public static byte[] getECKeyOrder(PrivateKey key) { - if (key instanceof ECKey) { - ECParameterSpec params = ((ECKey) key).getParams(); - return params.getOrder().toByteArray(); - } else { - Log.w(TAG, "Not an ECKey instance!"); - return null; - } - } - - /** - * Returns the encoded data corresponding to a given PrivateKey. - * Note that this will fail for platform keys on Android 4.0.4 - * and higher. It can be used on 4.0.3 and older platforms to - * route around the platform bug described below. - * @param key A PrivateKey instance - * @return encoded key as PKCS#8 byte array, can be null. - */ - @CalledByNative - public static byte[] getPrivateKeyEncodedBytes(PrivateKey key) { - return key.getEncoded(); - } - - /** - * Sign a given message with a given PrivateKey object. This method - * shall only be used to implement signing in the context of SSL - * client certificate support. - * - * The message will actually be a hash, computed by OpenSSL itself, - * depending on the type of the key. The result should match exactly - * what the vanilla implementations of the following OpenSSL function - * calls do: - * - * - For a RSA private key, this should be equivalent to calling - * RSA_private_encrypt(..., RSA_PKCS1_PADDING), i.e. it must - * generate a raw RSA signature. The message must be either a - * combined, 36-byte MD5+SHA1 message digest or a DigestInfo - * value wrapping a message digest. - * - * - For a DSA and ECDSA private keys, this should be equivalent to - * calling DSA_sign(0,...) and ECDSA_sign(0,...) respectively. The - * message must be a hash and the function shall compute a direct - * DSA/ECDSA signature for it. - * - * @param privateKey The PrivateKey handle. - * @param message The message to sign. - * @return signature as a byte buffer. - * - * Important: Due to a platform bug, this function will always fail on - * Android < 4.2 for RSA PrivateKey objects. See the - * getOpenSSLHandleForPrivateKey() below for work-around. - */ - @CalledByNative - public static byte[] rawSignDigestWithPrivateKey(PrivateKey privateKey, - byte[] message) { - // Get the Signature for this key. - Signature signature = null; - // Hint: Algorithm names come from: - // http://docs.oracle.com/javase/6/docs/technotes/guides/security/StandardNames.html - try { - if (privateKey instanceof RSAPrivateKey) { - // IMPORTANT: Due to a platform bug, this will throw NoSuchAlgorithmException - // on Android 4.0.x and 4.1.x. Fixed in 4.2 and higher. - // See https://android-review.googlesource.com/#/c/40352/ - signature = Signature.getInstance("NONEwithRSA"); - } else if (privateKey instanceof DSAPrivateKey) { - signature = Signature.getInstance("NONEwithDSA"); - } else if (privateKey instanceof ECPrivateKey) { - signature = Signature.getInstance("NONEwithECDSA"); - } - } catch (NoSuchAlgorithmException e) { - ; - } - - if (signature == null) { - Log.e(TAG, "Unsupported private key algorithm: " + privateKey.getAlgorithm()); - return null; - } - - // Sign the message. - try { - signature.initSign(privateKey); - signature.update(message); - return signature.sign(); - } catch (Exception e) { - Log.e(TAG, "Exception while signing message with " + privateKey.getAlgorithm() + - " private key: " + e); - return null; - } - } - - /** - * Return the type of a given PrivateKey object. This is an integer - * that maps to one of the values defined by org.chromium.net.PrivateKeyType, - * which is itself auto-generated from net/android/private_key_type_list.h - * @param privateKey The PrivateKey handle - * @return key type, or PrivateKeyType.INVALID if unknown. - */ - @CalledByNative - public static int getPrivateKeyType(PrivateKey privateKey) { - if (privateKey instanceof RSAPrivateKey) - return PrivateKeyType.RSA; - if (privateKey instanceof DSAPrivateKey) - return PrivateKeyType.DSA; - if (privateKey instanceof ECPrivateKey) - return PrivateKeyType.ECDSA; - else - return PrivateKeyType.INVALID; - } - - /** - * Return the system EVP_PKEY handle corresponding to a given PrivateKey - * object, obtained through reflection. - * - * This shall only be used when the "NONEwithRSA" signature is not - * available, as described in rawSignDigestWithPrivateKey(). I.e. - * never use this on Android 4.2 or higher. - * - * This can only work in Android 4.0.4 and higher, for older versions - * of the platform (e.g. 4.0.3), there is no system OpenSSL EVP_PKEY, - * but the private key contents can be retrieved directly with - * the getEncoded() method. - * - * This assumes that the target device uses a vanilla AOSP - * implementation of its java.security classes, which is also - * based on OpenSSL (fortunately, no OEM has apperently changed to - * a different implementation, according to the Android team). - * - * Note that the object returned was created with the platform version - * of OpenSSL, and _not_ the one that comes with Chromium. Whether the - * object can be used safely with the Chromium OpenSSL library depends - * on differences between their actual ABI / implementation details. - * - * To better understand what's going on below, please refer to the - * following source files in the Android 4.0.4 and 4.1 source trees: - * libcore/luni/src/main/java/org/apache/harmony/xnet/provider/jsse/OpenSSLRSAPrivateKey.java - * libcore/luni/src/main/native/org_apache_harmony_xnet_provider_jsse_NativeCrypto.cpp - * - * @param privateKey The PrivateKey handle. - * @return The EVP_PKEY handle, as a 32-bit integer (0 if not available) - */ - @CalledByNative - public static int getOpenSSLHandleForPrivateKey(PrivateKey privateKey) { - // Sanity checks - if (privateKey == null) { - Log.e(TAG, "privateKey == null"); - return 0; - } - if (!(privateKey instanceof RSAPrivateKey)) { - Log.e(TAG, "does not implement RSAPrivateKey"); - return 0; - } - // First, check that this is a proper instance of OpenSSLRSAPrivateKey - // or one of its sub-classes. - Class<?> superClass; - try { - superClass = Class.forName( - "org.apache.harmony.xnet.provider.jsse.OpenSSLRSAPrivateKey"); - } catch (Exception e) { - // This may happen if the target device has a completely different - // implementation of the java.security APIs, compared to vanilla - // Android. Highly unlikely, but still possible. - Log.e(TAG, "Cannot find system OpenSSLRSAPrivateKey class: " + e); - return 0; - } - if (!superClass.isInstance(privateKey)) { - // This may happen if the PrivateKey was not created by the "AndroidOpenSSL" - // provider, which should be the default. That could happen if an OEM decided - // to implement a different default provider. Also highly unlikely. - Log.e(TAG, "Private key is not an OpenSSLRSAPrivateKey instance, its class name is:" + - privateKey.getClass().getCanonicalName()); - return 0; - } - - try { - // Use reflection to invoke the 'getOpenSSLKey()' method on - // the private key. This returns another Java object that wraps - // a native EVP_PKEY. Note that the method is final, so calling - // the superclass implementation is ok. - Method getKey = superClass.getDeclaredMethod("getOpenSSLKey"); - getKey.setAccessible(true); - Object opensslKey = null; - try { - opensslKey = getKey.invoke(privateKey); - } finally { - getKey.setAccessible(false); - } - if (opensslKey == null) { - // Bail when detecting OEM "enhancement". - Log.e(TAG, "getOpenSSLKey() returned null"); - return 0; - } - - // Use reflection to invoke the 'getPkeyContext' method on the - // result of the getOpenSSLKey(). This is an 32-bit integer - // which is the address of an EVP_PKEY object. - Method getPkeyContext; - try { - getPkeyContext = opensslKey.getClass().getDeclaredMethod("getPkeyContext"); - } catch (Exception e) { - // Bail here too, something really not working as expected. - Log.e(TAG, "No getPkeyContext() method on OpenSSLKey member:" + e); - return 0; - } - getPkeyContext.setAccessible(true); - int evp_pkey = 0; - try { - evp_pkey = (Integer) getPkeyContext.invoke(opensslKey); - } finally { - getPkeyContext.setAccessible(false); - } - if (evp_pkey == 0) { - // The PrivateKey is probably rotten for some reason. - Log.e(TAG, "getPkeyContext() returned null"); - } - return evp_pkey; - - } catch (Exception e) { - Log.e(TAG, "Exception while trying to retrieve system EVP_PKEY handle: " + e); - return 0; - } - } -} diff --git a/chromium/net/android/java/src/org/chromium/net/AndroidNetworkLibrary.java b/chromium/net/android/java/src/org/chromium/net/AndroidNetworkLibrary.java deleted file mode 100644 index 95752cca8b2..00000000000 --- a/chromium/net/android/java/src/org/chromium/net/AndroidNetworkLibrary.java +++ /dev/null @@ -1,232 +0,0 @@ -// Copyright (c) 2012 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -package org.chromium.net; - -import android.content.ActivityNotFoundException; -import android.content.Context; -import android.content.Intent; -import android.security.KeyChain; -import android.util.Log; - -import org.chromium.base.CalledByNative; -import org.chromium.base.CalledByNativeUnchecked; -import org.chromium.net.CertVerifyResultAndroid; -import org.chromium.net.CertificateMimeType; - -import java.net.Inet6Address; -import java.net.InetAddress; -import java.net.NetworkInterface; -import java.net.SocketException; -import java.net.URLConnection; -import java.security.KeyStoreException; -import java.security.NoSuchAlgorithmException; -import java.security.cert.CertificateException; -import java.util.Enumeration; - -/** - * This class implements net utilities required by the net component. - */ -class AndroidNetworkLibrary { - - private static final String TAG = "AndroidNetworkLibrary"; - - /** - * Stores the key pair through the CertInstaller activity. - * @param context: current application context. - * @param public_key: The public key bytes as DER-encoded SubjectPublicKeyInfo (X.509) - * @param private_key: The private key as DER-encoded PrivateKeyInfo (PKCS#8). - * @return: true on success, false on failure. - * - * Note that failure means that the function could not launch the CertInstaller - * activity. Whether the keys are valid or properly installed will be indicated - * by the CertInstaller UI itself. - */ - @CalledByNative - static public boolean storeKeyPair(Context context, byte[] public_key, byte[] private_key) { - // TODO(digit): Use KeyChain official extra values to pass the public and private - // keys when they're available. The "KEY" and "PKEY" hard-coded constants were taken - // from the platform sources, since there are no official KeyChain.EXTRA_XXX definitions - // for them. b/5859651 - try { - Intent intent = KeyChain.createInstallIntent(); - intent.putExtra("PKEY", private_key); - intent.putExtra("KEY", public_key); - intent.addFlags(Intent.FLAG_ACTIVITY_NEW_TASK); - context.startActivity(intent); - return true; - } catch (ActivityNotFoundException e) { - Log.w(TAG, "could not store key pair: " + e); - } - return false; - } - - /** - * Adds a cryptographic file (User certificate, a CA certificate or - * PKCS#12 keychain) through the system's CertInstaller activity. - * - * @param context: current application context. - * @param cert_type: cryptographic file type. E.g. CertificateMimeType.X509_USER_CERT - * @param data: certificate/keychain data bytes. - * @return true on success, false on failure. - * - * Note that failure only indicates that the function couldn't launch the - * CertInstaller activity, not that the certificate/keychain was properly - * installed to the keystore. - */ - @CalledByNative - static public boolean storeCertificate(Context context, int cert_type, byte[] data) { - try { - Intent intent = KeyChain.createInstallIntent(); - intent.addFlags(Intent.FLAG_ACTIVITY_NEW_TASK); - - switch (cert_type) { - case CertificateMimeType.X509_USER_CERT: - case CertificateMimeType.X509_CA_CERT: - intent.putExtra(KeyChain.EXTRA_CERTIFICATE, data); - break; - - case CertificateMimeType.PKCS12_ARCHIVE: - intent.putExtra(KeyChain.EXTRA_PKCS12, data); - break; - - default: - Log.w(TAG, "invalid certificate type: " + cert_type); - return false; - } - context.startActivity(intent); - return true; - } catch (ActivityNotFoundException e) { - Log.w(TAG, "could not store crypto file: " + e); - } - return false; - } - - /** - * @return the mime type (if any) that is associated with the file - * extension. Returns null if no corresponding mime type exists. - */ - @CalledByNative - static public String getMimeTypeFromExtension(String extension) { - return URLConnection.guessContentTypeFromName("foo." + extension); - } - - /** - * @return true if it can determine that only loopback addresses are - * configured. i.e. if only 127.0.0.1 and ::1 are routable. Also - * returns false if it cannot determine this. - */ - @CalledByNative - static public boolean haveOnlyLoopbackAddresses() { - Enumeration<NetworkInterface> list = null; - try { - list = NetworkInterface.getNetworkInterfaces(); - if (list == null) return false; - } catch (Exception e) { - Log.w(TAG, "could not get network interfaces: " + e); - return false; - } - - while (list.hasMoreElements()) { - NetworkInterface netIf = list.nextElement(); - try { - if (netIf.isUp() && !netIf.isLoopback()) return false; - } catch (SocketException e) { - continue; - } - } - return true; - } - - /** - * @return the network interfaces list (if any) string. The items in - * the list string are delimited by a semicolon ";", each item - * is a network interface name and address pair and formatted - * as "name,address". e.g. - * eth0,10.0.0.2;eth0,fe80::5054:ff:fe12:3456 - * represents a network list string which containts two items. - */ - @CalledByNative - static public String getNetworkList() { - Enumeration<NetworkInterface> list = null; - try { - list = NetworkInterface.getNetworkInterfaces(); - if (list == null) return ""; - } catch (SocketException e) { - Log.w(TAG, "Unable to get network interfaces: " + e); - return ""; - } - - StringBuilder result = new StringBuilder(); - while (list.hasMoreElements()) { - NetworkInterface netIf = list.nextElement(); - try { - // Skip loopback interfaces, and ones which are down. - if (!netIf.isUp() || netIf.isLoopback()) - continue; - Enumeration<InetAddress> addressList = netIf.getInetAddresses(); - while (addressList.hasMoreElements()) { - InetAddress address = addressList.nextElement(); - // Skip loopback addresses configured on non-loopback interfaces. - if (address.isLoopbackAddress()) - continue; - StringBuilder addressString = new StringBuilder(); - addressString.append(netIf.getName()); - addressString.append(","); - - String ipAddress = address.getHostAddress(); - if (address instanceof Inet6Address && ipAddress.contains("%")) { - ipAddress = ipAddress.substring(0, ipAddress.lastIndexOf("%")); - } - addressString.append(ipAddress); - - if (result.length() != 0) - result.append(";"); - result.append(addressString.toString()); - } - } catch (SocketException e) { - continue; - } - } - return result.toString(); - } - - /** - * Validate the server's certificate chain is trusted. - * - * @param certChain The ASN.1 DER encoded bytes for certificates. - * @param authType The key exchange algorithm name (e.g. RSA) - * @return Android certificate verification result code. - */ - @CalledByNative - public static int verifyServerCertificates(byte[][] certChain, String authType) { - try { - return X509Util.verifyServerCertificates(certChain, authType); - } catch (KeyStoreException e) { - return CertVerifyResultAndroid.VERIFY_FAILED; - } catch (NoSuchAlgorithmException e) { - return CertVerifyResultAndroid.VERIFY_FAILED; - } - } - - /** - * Adds a test root certificate to the local trust store. - * @param rootCert DER encoded bytes of the certificate. - */ - @CalledByNativeUnchecked - public static void addTestRootCertificate(byte[] rootCert) throws CertificateException, - KeyStoreException, NoSuchAlgorithmException { - X509Util.addTestRootCertificate(rootCert); - } - - /** - * Removes all test root certificates added by |addTestRootCertificate| calls from the local - * trust store. - */ - @CalledByNativeUnchecked - public static void clearTestRootCertificates() throws NoSuchAlgorithmException, - CertificateException, KeyStoreException { - X509Util.clearTestRootCertificates(); - } -} diff --git a/chromium/net/android/java/src/org/chromium/net/GURLUtils.java b/chromium/net/android/java/src/org/chromium/net/GURLUtils.java deleted file mode 100644 index 719ddeabb23..00000000000 --- a/chromium/net/android/java/src/org/chromium/net/GURLUtils.java +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright (c) 2012 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -package org.chromium.net; - -import org.chromium.base.JNINamespace; - -/** - * Class to access the GURL library from java. - */ -@JNINamespace("net") -public final class GURLUtils { - - /** - * Get the origin of an url: Ex getOrigin("http://www.example.com:8080/index.html?bar=foo") - * would return "http://www.example.com:8080". It will return an empty string for an - * invalid url. - * - * @return The origin of the url - */ - public static String getOrigin(String url) { - return nativeGetOrigin(url); - } - - /** - * Get the scheme of the url (e.g. http, https, file). The returned string - * contains everything before the "://". - * - * @return The scheme of the url. - */ - public static String getScheme(String url) { - return nativeGetScheme(url); - } - - private static native String nativeGetOrigin(String url); - private static native String nativeGetScheme(String url); -} diff --git a/chromium/net/android/java/src/org/chromium/net/NetworkChangeNotifier.java b/chromium/net/android/java/src/org/chromium/net/NetworkChangeNotifier.java deleted file mode 100644 index a5de98313c2..00000000000 --- a/chromium/net/android/java/src/org/chromium/net/NetworkChangeNotifier.java +++ /dev/null @@ -1,224 +0,0 @@ -// Copyright (c) 2012 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -package org.chromium.net; - -import android.content.Context; - -import org.chromium.base.CalledByNative; -import org.chromium.base.JNINamespace; -import org.chromium.base.NativeClassQualifiedName; -import org.chromium.base.ObserverList; - -import java.util.ArrayList; - -/** - * Triggers updates to the underlying network state in Chrome. - * - * By default, connectivity is assumed and changes must pushed from the embedder via the - * forceConnectivityState function. - * Embedders may choose to have this class auto-detect changes in network connectivity by invoking - * the setAutoDetectConnectivityState function. - * - * WARNING: This class is not thread-safe. - */ -@JNINamespace("net") -public class NetworkChangeNotifier { - /** - * Alerted when the connection type of the network changes. - * The alert is fired on the UI thread. - */ - public interface ConnectionTypeObserver { - public void onConnectionTypeChanged(int connectionType); - } - - // These constants must always match the ones in network_change_notifier.h. - public static final int CONNECTION_UNKNOWN = 0; - public static final int CONNECTION_ETHERNET = 1; - public static final int CONNECTION_WIFI = 2; - public static final int CONNECTION_2G = 3; - public static final int CONNECTION_3G = 4; - public static final int CONNECTION_4G = 5; - public static final int CONNECTION_NONE = 6; - - private final Context mContext; - private final ArrayList<Integer> mNativeChangeNotifiers; - private final ObserverList<ConnectionTypeObserver> mConnectionTypeObservers; - private NetworkChangeNotifierAutoDetect mAutoDetector; - private int mCurrentConnectionType = CONNECTION_UNKNOWN; - - private static NetworkChangeNotifier sInstance; - - private NetworkChangeNotifier(Context context) { - mContext = context; - mNativeChangeNotifiers = new ArrayList<Integer>(); - mConnectionTypeObservers = new ObserverList<ConnectionTypeObserver>(); - } - - /** - * Initializes the singleton once. - */ - @CalledByNative - public static NetworkChangeNotifier init(Context context) { - if (sInstance == null) { - sInstance = new NetworkChangeNotifier(context); - } - return sInstance; - } - - public static boolean isInitialized() { - return sInstance != null; - } - - static void resetInstanceForTests(Context context) { - sInstance = new NetworkChangeNotifier(context); - } - - @CalledByNative - public int getCurrentConnectionType() { - return mCurrentConnectionType; - } - - /** - * Adds a native-side observer. - */ - @CalledByNative - public void addNativeObserver(int nativeChangeNotifier) { - mNativeChangeNotifiers.add(nativeChangeNotifier); - } - - /** - * Removes a native-side observer. - */ - @CalledByNative - public void removeNativeObserver(int nativeChangeNotifier) { - // Please keep the cast performing the boxing below. It ensures that the right method - // overload is used. ArrayList<T> has both remove(int index) and remove(T element). - mNativeChangeNotifiers.remove((Integer) nativeChangeNotifier); - } - - /** - * Returns the singleton instance. - */ - public static NetworkChangeNotifier getInstance() { - assert sInstance != null; - return sInstance; - } - - /** - * Enables auto detection of the current network state based on notifications from the system. - * Note that passing true here requires the embedding app have the platform ACCESS_NETWORK_STATE - * permission. - * - * @param shouldAutoDetect true if the NetworkChangeNotifier should listen for system changes in - * network connectivity. - */ - public static void setAutoDetectConnectivityState(boolean shouldAutoDetect) { - getInstance().setAutoDetectConnectivityStateInternal(shouldAutoDetect); - } - - private void destroyAutoDetector() { - if (mAutoDetector != null) { - mAutoDetector.destroy(); - mAutoDetector = null; - } - } - - private void setAutoDetectConnectivityStateInternal(boolean shouldAutoDetect) { - if (shouldAutoDetect) { - if (mAutoDetector == null) { - mAutoDetector = new NetworkChangeNotifierAutoDetect( - new NetworkChangeNotifierAutoDetect.Observer() { - @Override - public void onConnectionTypeChanged(int newConnectionType) { - updateCurrentConnectionType(newConnectionType); - } - }, - mContext); - mCurrentConnectionType = mAutoDetector.getCurrentConnectionType(); - } - } else { - destroyAutoDetector(); - } - } - - /** - * Updates the perceived network state when not auto-detecting changes to connectivity. - * - * @param networkAvailable True if the NetworkChangeNotifier should perceive a "connected" - * state, false implies "disconnected". - */ - @CalledByNative - public static void forceConnectivityState(boolean networkAvailable) { - setAutoDetectConnectivityState(false); - getInstance().forceConnectivityStateInternal(networkAvailable); - } - - private void forceConnectivityStateInternal(boolean forceOnline) { - boolean connectionCurrentlyExists = mCurrentConnectionType != CONNECTION_NONE; - if (connectionCurrentlyExists != forceOnline) { - updateCurrentConnectionType(forceOnline ? CONNECTION_UNKNOWN : CONNECTION_NONE); - } - } - - private void updateCurrentConnectionType(int newConnectionType) { - mCurrentConnectionType = newConnectionType; - notifyObserversOfConnectionTypeChange(newConnectionType); - } - - /** - * Alerts all observers of a connection change. - */ - void notifyObserversOfConnectionTypeChange(int newConnectionType) { - for (Integer nativeChangeNotifier : mNativeChangeNotifiers) { - nativeNotifyConnectionTypeChanged(nativeChangeNotifier, newConnectionType); - } - for (ConnectionTypeObserver observer : mConnectionTypeObservers) { - observer.onConnectionTypeChanged(newConnectionType); - } - } - - /** - * Adds an observer for any connection type changes. - */ - public static void addConnectionTypeObserver(ConnectionTypeObserver observer) { - getInstance().addConnectionTypeObserverInternal(observer); - } - - private void addConnectionTypeObserverInternal(ConnectionTypeObserver observer) { - if (!mConnectionTypeObservers.hasObserver(observer)) { - mConnectionTypeObservers.addObserver(observer); - } - } - - /** - * Removes an observer for any connection type changes. - */ - public static void removeConnectionTypeObserver(ConnectionTypeObserver observer) { - getInstance().removeConnectionTypeObserverInternal(observer); - } - - private void removeConnectionTypeObserverInternal(ConnectionTypeObserver observer) { - mConnectionTypeObservers.removeObserver(observer); - } - - @NativeClassQualifiedName("NetworkChangeNotifierDelegateAndroid") - private native void nativeNotifyConnectionTypeChanged(int nativePtr, int newConnectionType); - - @NativeClassQualifiedName("NetworkChangeNotifierDelegateAndroid") - private native int nativeGetConnectionType(int nativePtr); - - // For testing only. - public static NetworkChangeNotifierAutoDetect getAutoDetectorForTest() { - return getInstance().mAutoDetector; - } - - /** - * Checks if there currently is connectivity. - */ - public static boolean isOnline() { - int connectionType = getInstance().getCurrentConnectionType(); - return connectionType != CONNECTION_UNKNOWN && connectionType != CONNECTION_NONE; - } -} diff --git a/chromium/net/android/java/src/org/chromium/net/NetworkChangeNotifierAutoDetect.java b/chromium/net/android/java/src/org/chromium/net/NetworkChangeNotifierAutoDetect.java deleted file mode 100644 index 038cb3124ac..00000000000 --- a/chromium/net/android/java/src/org/chromium/net/NetworkChangeNotifierAutoDetect.java +++ /dev/null @@ -1,196 +0,0 @@ -// Copyright (c) 2012 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -package org.chromium.net; - -import android.content.BroadcastReceiver; -import android.content.Context; -import android.content.Intent; -import android.content.IntentFilter; -import android.net.ConnectivityManager; -import android.net.NetworkInfo; -import android.telephony.TelephonyManager; -import android.util.Log; - -import org.chromium.base.ActivityStatus; - -/** - * Used by the NetworkChangeNotifier to listens to platform changes in connectivity. - * Note that use of this class requires that the app have the platform - * ACCESS_NETWORK_STATE permission. - */ -public class NetworkChangeNotifierAutoDetect extends BroadcastReceiver - implements ActivityStatus.StateListener { - - /** Queries the ConnectivityManager for information about the current connection. */ - static class ConnectivityManagerDelegate { - private final ConnectivityManager mConnectivityManager; - - ConnectivityManagerDelegate(Context context) { - mConnectivityManager = - (ConnectivityManager) context.getSystemService(Context.CONNECTIVITY_SERVICE); - } - - // For testing. - ConnectivityManagerDelegate() { - // All the methods below should be overridden. - mConnectivityManager = null; - } - - boolean activeNetworkExists() { - return mConnectivityManager.getActiveNetworkInfo() != null; - } - - boolean isConnected() { - return mConnectivityManager.getActiveNetworkInfo().isConnected(); - } - - int getNetworkType() { - return mConnectivityManager.getActiveNetworkInfo().getType(); - } - - int getNetworkSubtype() { - return mConnectivityManager.getActiveNetworkInfo().getSubtype(); - } - } - - private static final String TAG = "NetworkChangeNotifierAutoDetect"; - - private final NetworkConnectivityIntentFilter mIntentFilter = - new NetworkConnectivityIntentFilter(); - - private final Observer mObserver; - - private final Context mContext; - private ConnectivityManagerDelegate mConnectivityManagerDelegate; - private boolean mRegistered; - private int mConnectionType; - - /** - * Observer notified on the UI thread whenever a new connection type was detected. - */ - public static interface Observer { - public void onConnectionTypeChanged(int newConnectionType); - } - - public NetworkChangeNotifierAutoDetect(Observer observer, Context context) { - mObserver = observer; - mContext = context.getApplicationContext(); - mConnectivityManagerDelegate = new ConnectivityManagerDelegate(context); - mConnectionType = getCurrentConnectionType(); - ActivityStatus.registerStateListener(this); - } - - /** - * Allows overriding the ConnectivityManagerDelegate for tests. - */ - void setConnectivityManagerDelegateForTests(ConnectivityManagerDelegate delegate) { - mConnectivityManagerDelegate = delegate; - } - - public void destroy() { - unregisterReceiver(); - } - - /** - * Register a BroadcastReceiver in the given context. - */ - private void registerReceiver() { - if (!mRegistered) { - mRegistered = true; - mContext.registerReceiver(this, mIntentFilter); - } - } - - /** - * Unregister the BroadcastReceiver in the given context. - */ - private void unregisterReceiver() { - if (mRegistered) { - mRegistered = false; - mContext.unregisterReceiver(this); - } - } - - public int getCurrentConnectionType() { - // Track exactly what type of connection we have. - if (!mConnectivityManagerDelegate.activeNetworkExists() || - !mConnectivityManagerDelegate.isConnected()) { - return NetworkChangeNotifier.CONNECTION_NONE; - } - - switch (mConnectivityManagerDelegate.getNetworkType()) { - case ConnectivityManager.TYPE_ETHERNET: - return NetworkChangeNotifier.CONNECTION_ETHERNET; - case ConnectivityManager.TYPE_WIFI: - return NetworkChangeNotifier.CONNECTION_WIFI; - case ConnectivityManager.TYPE_WIMAX: - return NetworkChangeNotifier.CONNECTION_4G; - case ConnectivityManager.TYPE_MOBILE: - // Use information from TelephonyManager to classify the connection. - switch (mConnectivityManagerDelegate.getNetworkSubtype()) { - case TelephonyManager.NETWORK_TYPE_GPRS: - case TelephonyManager.NETWORK_TYPE_EDGE: - case TelephonyManager.NETWORK_TYPE_CDMA: - case TelephonyManager.NETWORK_TYPE_1xRTT: - case TelephonyManager.NETWORK_TYPE_IDEN: - return NetworkChangeNotifier.CONNECTION_2G; - case TelephonyManager.NETWORK_TYPE_UMTS: - case TelephonyManager.NETWORK_TYPE_EVDO_0: - case TelephonyManager.NETWORK_TYPE_EVDO_A: - case TelephonyManager.NETWORK_TYPE_HSDPA: - case TelephonyManager.NETWORK_TYPE_HSUPA: - case TelephonyManager.NETWORK_TYPE_HSPA: - case TelephonyManager.NETWORK_TYPE_EVDO_B: - case TelephonyManager.NETWORK_TYPE_EHRPD: - case TelephonyManager.NETWORK_TYPE_HSPAP: - return NetworkChangeNotifier.CONNECTION_3G; - case TelephonyManager.NETWORK_TYPE_LTE: - return NetworkChangeNotifier.CONNECTION_4G; - default: - return NetworkChangeNotifier.CONNECTION_UNKNOWN; - } - default: - return NetworkChangeNotifier.CONNECTION_UNKNOWN; - } - } - - // BroadcastReceiver - @Override - public void onReceive(Context context, Intent intent) { - connectionTypeChanged(); - } - - // ActivityStatus.StateListener - @Override - public void onActivityStateChange(int state) { - if (state == ActivityStatus.RESUMED) { - // Note that this also covers the case where the main activity is created. The CREATED - // event is always followed by the RESUMED event. This is a temporary "hack" until - // http://crbug.com/176837 is fixed. The CREATED event can't be used reliably for now - // since its notification is deferred. This means that it can immediately follow a - // DESTROYED/STOPPED/... event which is problematic. - // TODO(pliard): fix http://crbug.com/176837. - connectionTypeChanged(); - registerReceiver(); - } else if (state == ActivityStatus.PAUSED) { - unregisterReceiver(); - } - } - - private void connectionTypeChanged() { - int newConnectionType = getCurrentConnectionType(); - if (newConnectionType == mConnectionType) return; - - mConnectionType = newConnectionType; - Log.d(TAG, "Network connectivity changed, type is: " + mConnectionType); - mObserver.onConnectionTypeChanged(newConnectionType); - } - - private static class NetworkConnectivityIntentFilter extends IntentFilter { - NetworkConnectivityIntentFilter() { - addAction(ConnectivityManager.CONNECTIVITY_ACTION); - } - } -} diff --git a/chromium/net/android/java/src/org/chromium/net/ProxyChangeListener.java b/chromium/net/android/java/src/org/chromium/net/ProxyChangeListener.java deleted file mode 100644 index 9c59bcccf6d..00000000000 --- a/chromium/net/android/java/src/org/chromium/net/ProxyChangeListener.java +++ /dev/null @@ -1,115 +0,0 @@ -// Copyright (c) 2012 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -package org.chromium.net; - -import android.content.BroadcastReceiver; -import android.content.Context; -import android.content.Intent; -import android.content.IntentFilter; -import android.net.Proxy; - -import org.chromium.base.CalledByNative; -import org.chromium.base.JNINamespace; -import org.chromium.base.NativeClassQualifiedName; - -// This class partners with native ProxyConfigServiceAndroid to listen for -// proxy change notifications from Android. -@JNINamespace("net") -public class ProxyChangeListener { - private static final String TAG = "ProxyChangeListener"; - private static boolean sEnabled = true; - - private int mNativePtr; - private Context mContext; - private ProxyReceiver mProxyReceiver; - private Delegate mDelegate; - - public interface Delegate { - public void proxySettingsChanged(); - } - - private ProxyChangeListener(Context context) { - mContext = context; - } - - public static void setEnabled(boolean enabled) { - sEnabled = enabled; - } - - public void setDelegateForTesting(Delegate delegate) { - mDelegate = delegate; - } - - @CalledByNative - static public ProxyChangeListener create(Context context) { - return new ProxyChangeListener(context); - } - - @CalledByNative - static public String getProperty(String property) { - return System.getProperty(property); - } - - @CalledByNative - public void start(int nativePtr) { - assert mNativePtr == 0; - mNativePtr = nativePtr; - registerReceiver(); - } - - @CalledByNative - public void stop() { - mNativePtr = 0; - unregisterReceiver(); - } - - private class ProxyReceiver extends BroadcastReceiver { - @Override - public void onReceive(Context context, Intent intent) { - if (intent.getAction().equals(Proxy.PROXY_CHANGE_ACTION)) { - proxySettingsChanged(); - } - } - } - - private void proxySettingsChanged() { - if (!sEnabled) { - return; - } - if (mDelegate != null) { - mDelegate.proxySettingsChanged(); - } - if (mNativePtr == 0) { - return; - } - // Note that this code currently runs on a MESSAGE_LOOP_UI thread, but - // the C++ code must run the callbacks on the network thread. - nativeProxySettingsChanged(mNativePtr); - } - - private void registerReceiver() { - if (mProxyReceiver != null) { - return; - } - IntentFilter filter = new IntentFilter(); - filter.addAction(Proxy.PROXY_CHANGE_ACTION); - mProxyReceiver = new ProxyReceiver(); - mContext.getApplicationContext().registerReceiver(mProxyReceiver, filter); - } - - private void unregisterReceiver() { - if (mProxyReceiver == null) { - return; - } - mContext.unregisterReceiver(mProxyReceiver); - mProxyReceiver = null; - } - - /** - * See net/proxy/proxy_config_service_android.cc - */ - @NativeClassQualifiedName("ProxyConfigServiceAndroid::JNIDelegate") - private native void nativeProxySettingsChanged(int nativePtr); -} diff --git a/chromium/net/android/java/src/org/chromium/net/X509Util.java b/chromium/net/android/java/src/org/chromium/net/X509Util.java deleted file mode 100644 index 30007caab17..00000000000 --- a/chromium/net/android/java/src/org/chromium/net/X509Util.java +++ /dev/null @@ -1,233 +0,0 @@ -// Copyright (c) 2012 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -package org.chromium.net; - -import android.util.Log; - -import org.chromium.net.CertVerifyResultAndroid; - -import java.io.ByteArrayInputStream; -import java.io.IOException; -import java.security.KeyStore; -import java.security.KeyStoreException; -import java.security.NoSuchAlgorithmException; -import java.security.cert.CertificateException; -import java.security.cert.CertificateExpiredException; -import java.security.cert.CertificateNotYetValidException; -import java.security.cert.CertificateFactory; -import java.security.cert.CertificateParsingException; -import java.security.cert.X509Certificate; -import java.util.List; - -import javax.net.ssl.TrustManager; -import javax.net.ssl.TrustManagerFactory; -import javax.net.ssl.X509TrustManager; - -public class X509Util { - - private static final String TAG = "X509Util"; - - private static CertificateFactory sCertificateFactory; - - private static final String OID_TLS_SERVER_AUTH = "1.3.6.1.5.5.7.3.1"; - private static final String OID_ANY_EKU = "2.5.29.37.0"; - // Server-Gated Cryptography (necessary to support a few legacy issuers): - // Netscape: - private static final String OID_SERVER_GATED_NETSCAPE = "2.16.840.1.113730.4.1"; - // Microsoft: - private static final String OID_SERVER_GATED_MICROSOFT = "1.3.6.1.4.1.311.10.3.3"; - - /** - * Trust manager backed up by the read-only system certificate store. - */ - private static X509TrustManager sDefaultTrustManager; - - /** - * Trust manager backed up by a custom certificate store. We need such manager to plant test - * root CA to the trust store in testing. - */ - private static X509TrustManager sTestTrustManager; - private static KeyStore sTestKeyStore; - - /** - * Lock object used to synchronize all calls that modify or depend on the trust managers. - */ - private static final Object sLock = new Object(); - - /** - * Ensures that the trust managers and certificate factory are initialized. - */ - private static void ensureInitialized() throws CertificateException, - KeyStoreException, NoSuchAlgorithmException { - synchronized(sLock) { - if (sCertificateFactory == null) { - sCertificateFactory = CertificateFactory.getInstance("X.509"); - } - if (sDefaultTrustManager == null) { - sDefaultTrustManager = X509Util.createTrustManager(null); - } - if (sTestKeyStore == null) { - sTestKeyStore = KeyStore.getInstance(KeyStore.getDefaultType()); - try { - sTestKeyStore.load(null); - } catch(IOException e) {} // No IO operation is attempted. - } - if (sTestTrustManager == null) { - sTestTrustManager = X509Util.createTrustManager(sTestKeyStore); - } - } - } - - /** - * Creates a X509TrustManager backed up by the given key store. When null is passed as a key - * store, system default trust store is used. - * @throws KeyStoreException, NoSuchAlgorithmException on error initializing the TrustManager. - */ - private static X509TrustManager createTrustManager(KeyStore keyStore) throws KeyStoreException, - NoSuchAlgorithmException { - String algorithm = TrustManagerFactory.getDefaultAlgorithm(); - TrustManagerFactory tmf = TrustManagerFactory.getInstance(algorithm); - tmf.init(keyStore); - - for (TrustManager tm : tmf.getTrustManagers()) { - if (tm instanceof X509TrustManager) { - return (X509TrustManager) tm; - } - } - return null; - } - - /** - * After each modification of test key store, trust manager has to be generated again. - */ - private static void reloadTestTrustManager() throws KeyStoreException, - NoSuchAlgorithmException { - sTestTrustManager = X509Util.createTrustManager(sTestKeyStore); - } - - /** - * Convert a DER encoded certificate to an X509Certificate. - */ - public static X509Certificate createCertificateFromBytes(byte[] derBytes) throws - CertificateException, KeyStoreException, NoSuchAlgorithmException { - ensureInitialized(); - return (X509Certificate) sCertificateFactory.generateCertificate( - new ByteArrayInputStream(derBytes)); - } - - public static void addTestRootCertificate(byte[] rootCertBytes) throws CertificateException, - KeyStoreException, NoSuchAlgorithmException { - ensureInitialized(); - X509Certificate rootCert = createCertificateFromBytes(rootCertBytes); - synchronized (sLock) { - sTestKeyStore.setCertificateEntry( - "root_cert_" + Integer.toString(sTestKeyStore.size()), rootCert); - reloadTestTrustManager(); - } - } - - public static void clearTestRootCertificates() throws NoSuchAlgorithmException, - CertificateException, KeyStoreException { - ensureInitialized(); - synchronized (sLock) { - try { - sTestKeyStore.load(null); - reloadTestTrustManager(); - } catch (IOException e) {} // No IO operation is attempted. - } - } - - /** - * If an EKU extension is present in the end-entity certificate, it MUST contain either the - * anyEKU or serverAuth or netscapeSGC or Microsoft SGC EKUs. - * - * @return true if there is no EKU extension or if any of the EKU extensions is one of the valid - * OIDs for web server certificates. - * - * TODO(palmer): This can be removed after the equivalent change is made to the Android default - * TrustManager and that change is shipped to a large majority of Android users. - */ - static boolean verifyKeyUsage(X509Certificate certificate) throws CertificateException { - List<String> ekuOids; - try { - ekuOids = certificate.getExtendedKeyUsage(); - } catch (NullPointerException e) { - // getExtendedKeyUsage() can crash due to an Android platform bug. This probably - // happens when the EKU extension data is malformed so return false here. - // See http://crbug.com/233610 - return false; - } - if (ekuOids == null) - return true; - - for (String ekuOid : ekuOids) { - if (ekuOid.equals(OID_TLS_SERVER_AUTH) || - ekuOid.equals(OID_ANY_EKU) || - ekuOid.equals(OID_SERVER_GATED_NETSCAPE) || - ekuOid.equals(OID_SERVER_GATED_MICROSOFT)) { - return true; - } - } - - return false; - } - - public static int verifyServerCertificates(byte[][] certChain, String authType) - throws KeyStoreException, NoSuchAlgorithmException { - if (certChain == null || certChain.length == 0 || certChain[0] == null) { - throw new IllegalArgumentException("Expected non-null and non-empty certificate " + - "chain passed as |certChain|. |certChain|=" + certChain); - } - - try { - ensureInitialized(); - } catch (CertificateException e) { - return CertVerifyResultAndroid.VERIFY_FAILED; - } - - X509Certificate[] serverCertificates = new X509Certificate[certChain.length]; - try { - for (int i = 0; i < certChain.length; ++i) { - serverCertificates[i] = createCertificateFromBytes(certChain[i]); - } - } catch (CertificateException e) { - return CertVerifyResultAndroid.VERIFY_UNABLE_TO_PARSE; - } - - // Expired and not yet valid certificates would be rejected by the trust managers, but the - // trust managers report all certificate errors using the general CertificateException. In - // order to get more granular error information, cert validity time range is being checked - // separately. - try { - serverCertificates[0].checkValidity(); - if (!verifyKeyUsage(serverCertificates[0])) - return CertVerifyResultAndroid.VERIFY_INCORRECT_KEY_USAGE; - } catch (CertificateExpiredException e) { - return CertVerifyResultAndroid.VERIFY_EXPIRED; - } catch (CertificateNotYetValidException e) { - return CertVerifyResultAndroid.VERIFY_NOT_YET_VALID; - } catch (CertificateException e) { - return CertVerifyResultAndroid.VERIFY_FAILED; - } - - synchronized (sLock) { - try { - sDefaultTrustManager.checkServerTrusted(serverCertificates, authType); - return CertVerifyResultAndroid.VERIFY_OK; - } catch (CertificateException eDefaultManager) { - try { - sTestTrustManager.checkServerTrusted(serverCertificates, authType); - return CertVerifyResultAndroid.VERIFY_OK; - } catch (CertificateException eTestManager) { - // Neither of the trust managers confirms the validity of the certificate chain, - // log the error message returned by the system trust manager. - Log.i(TAG, "Failed to validate the certificate chain, error: " + - eDefaultManager.getMessage()); - return CertVerifyResultAndroid.VERIFY_NO_TRUSTED_ROOT; - } - } - } - } -} diff --git a/chromium/net/android/javatests/src/org/chromium/net/AndroidKeyStoreTestUtil.java b/chromium/net/android/javatests/src/org/chromium/net/AndroidKeyStoreTestUtil.java deleted file mode 100644 index 460dc50cabb..00000000000 --- a/chromium/net/android/javatests/src/org/chromium/net/AndroidKeyStoreTestUtil.java +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright (c) 2013 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -package org.chromium.net; - -import android.os.Build; -import android.util.Log; - -import java.security.PrivateKey; -import java.security.PrivateKey; -import java.security.Signature; -import java.security.KeyFactory; -import java.security.spec.KeySpec; -import java.security.spec.PKCS8EncodedKeySpec; -import java.security.KeyStoreException; -import java.security.spec.InvalidKeySpecException; -import java.security.NoSuchAlgorithmException; - -import org.chromium.base.CalledByNative; -import org.chromium.base.JNINamespace; -import org.chromium.net.PrivateKeyType; - -@JNINamespace("net::android") -public class AndroidKeyStoreTestUtil { - - private static final String TAG = "AndroidKeyStoreTestUtil"; - - /** - * Called from native code to create a PrivateKey object from its - * encoded PKCS#8 representation. - * @param type The key type, accoding to PrivateKeyType. - * @return new PrivateKey handle, or null in case of error. - */ - @CalledByNative - public static PrivateKey createPrivateKeyFromPKCS8(int type, - byte[] encoded_key) { - String algorithm = null; - switch (type) { - case PrivateKeyType.RSA: - algorithm = "RSA"; - break; - case PrivateKeyType.DSA: - algorithm = "DSA"; - break; - case PrivateKeyType.ECDSA: - algorithm = "EC"; - break; - default: - return null; - } - - try { - KeyFactory factory = KeyFactory.getInstance(algorithm); - KeySpec ks = new PKCS8EncodedKeySpec(encoded_key); - PrivateKey key = factory.generatePrivate(ks); - return key; - - } catch (NoSuchAlgorithmException e) { - Log.e(TAG, "Could not create " + algorithm + " factory instance!"); - return null; - } catch (InvalidKeySpecException e) { - Log.e(TAG, "Could not load " + algorithm + " private key from bytes!"); - return null; - } - } -} diff --git a/chromium/net/android/javatests/src/org/chromium/net/AndroidProxySelectorTest.java b/chromium/net/android/javatests/src/org/chromium/net/AndroidProxySelectorTest.java deleted file mode 100644 index c705f69be37..00000000000 --- a/chromium/net/android/javatests/src/org/chromium/net/AndroidProxySelectorTest.java +++ /dev/null @@ -1,295 +0,0 @@ -// Copyright (c) 2012 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -/** - * Test suite for Android's default ProxySelector implementation. The purpose of these tests - * is to check that the behaviour of the ProxySelector implementation matches what we have - * implemented in net/proxy/proxy_config_service_android.cc. - * - * IMPORTANT: These test cases are generated from net/android/tools/proxy_test_cases.py, so if any - * of these tests fail, please be sure to edit that file and regenerate the test cases here and also - * in net/proxy/proxy_config_service_android_unittests.cc if required. - */ - -package org.chromium.net; - -import android.test.InstrumentationTestCase; -import android.test.suitebuilder.annotation.SmallTest; - -import java.net.Proxy; -import java.net.ProxySelector; -import java.net.URI; -import java.net.URISyntaxException; -import java.util.List; -import java.util.Properties; - -import org.chromium.base.test.util.Feature; - -public class AndroidProxySelectorTest extends InstrumentationTestCase { - Properties mProperties; - - public AndroidProxySelectorTest() { - // Start with a clean slate in case there is a system proxy configured. - mProperties = new Properties(); - } - - @Override - public void setUp() { - System.setProperties(mProperties); - } - - static String toString(Proxy proxy) { - if (proxy == Proxy.NO_PROXY) - return "DIRECT"; - // java.net.Proxy only knows about http and socks proxies. - Proxy.Type type = proxy.type(); - switch (type) { - case HTTP: return "PROXY " + proxy.address().toString(); - case SOCKS: return "SOCKS5 " + proxy.address().toString(); - case DIRECT: return "DIRECT"; - default: - // If a new proxy type is supported in future, add a case to match it. - fail("Unknown proxy type" + type); - return "unknown://"; - } - } - - static String toString(List<Proxy> proxies) { - StringBuilder builder = new StringBuilder(); - for (Proxy proxy : proxies) { - if (builder.length() > 0) - builder.append(';'); - builder.append(toString(proxy)); - } - return builder.toString(); - } - - static void checkMapping(String url, String expected) throws URISyntaxException { - URI uri = new URI(url); - List<Proxy> proxies = ProxySelector.getDefault().select(uri); - assertEquals("Mapping", expected, toString(proxies)); - } - - /** - * Test direct mapping when no proxy defined. - * - * @throws Exception - */ - @SmallTest - @Feature({"AndroidWebView"}) - public void testNoProxy() throws Exception { - checkMapping("ftp://example.com/", "DIRECT"); - checkMapping("http://example.com/", "DIRECT"); - checkMapping("https://example.com/", "DIRECT"); - } - - /** - * Test http.proxyHost and http.proxyPort works. - * - * @throws Exception - */ - @SmallTest - @Feature({"AndroidWebView"}) - public void testHttpProxyHostAndPort() throws Exception { - System.setProperty("http.proxyHost", "httpproxy.com"); - System.setProperty("http.proxyPort", "8080"); - checkMapping("ftp://example.com/", "DIRECT"); - checkMapping("http://example.com/", "PROXY httpproxy.com:8080"); - checkMapping("https://example.com/", "DIRECT"); - } - - /** - * We should get the default port (80) for proxied hosts. - * - * @throws Exception - */ - @SmallTest - @Feature({"AndroidWebView"}) - public void testHttpProxyHostOnly() throws Exception { - System.setProperty("http.proxyHost", "httpproxy.com"); - checkMapping("ftp://example.com/", "DIRECT"); - checkMapping("http://example.com/", "PROXY httpproxy.com:80"); - checkMapping("https://example.com/", "DIRECT"); - } - - /** - * http.proxyPort only should not result in any hosts being proxied. - * - * @throws Exception - */ - @SmallTest - @Feature({"AndroidWebView"}) - public void testHttpProxyPortOnly() throws Exception { - System.setProperty("http.proxyPort", "8080"); - checkMapping("ftp://example.com/", "DIRECT"); - checkMapping("http://example.com/", "DIRECT"); - checkMapping("https://example.com/", "DIRECT"); - } - - /** - * Test that HTTP non proxy hosts are mapped correctly - * - * @throws Exception - */ - @SmallTest - @Feature({"AndroidWebView"}) - public void testHttpNonProxyHosts1() throws Exception { - System.setProperty("http.nonProxyHosts", "slashdot.org"); - System.setProperty("http.proxyHost", "httpproxy.com"); - System.setProperty("http.proxyPort", "8080"); - checkMapping("http://example.com/", "PROXY httpproxy.com:8080"); - checkMapping("http://slashdot.org/", "DIRECT"); - } - - /** - * Test that | pattern works. - * - * @throws Exception - */ - @SmallTest - @Feature({"AndroidWebView"}) - public void testHttpNonProxyHosts2() throws Exception { - System.setProperty("http.nonProxyHosts", "slashdot.org|freecode.net"); - System.setProperty("http.proxyHost", "httpproxy.com"); - System.setProperty("http.proxyPort", "8080"); - checkMapping("http://example.com/", "PROXY httpproxy.com:8080"); - checkMapping("http://freecode.net/", "DIRECT"); - checkMapping("http://slashdot.org/", "DIRECT"); - } - - /** - * Test that * pattern works. - * - * @throws Exception - */ - @SmallTest - @Feature({"AndroidWebView"}) - public void testHttpNonProxyHosts3() throws Exception { - System.setProperty("http.nonProxyHosts", "*example.com"); - System.setProperty("http.proxyHost", "httpproxy.com"); - System.setProperty("http.proxyPort", "8080"); - checkMapping("http://example.com/", "DIRECT"); - checkMapping("http://slashdot.org/", "PROXY httpproxy.com:8080"); - checkMapping("http://www.example.com/", "DIRECT"); - } - - /** - * Test that FTP non proxy hosts are mapped correctly - * - * @throws Exception - */ - @SmallTest - @Feature({"AndroidWebView"}) - public void testFtpNonProxyHosts() throws Exception { - System.setProperty("ftp.nonProxyHosts", "slashdot.org"); - System.setProperty("ftp.proxyHost", "httpproxy.com"); - System.setProperty("ftp.proxyPort", "8080"); - checkMapping("ftp://example.com/", "PROXY httpproxy.com:8080"); - checkMapping("http://example.com/", "DIRECT"); - } - - /** - * Test ftp.proxyHost and ftp.proxyPort works. - * - * @throws Exception - */ - @SmallTest - @Feature({"AndroidWebView"}) - public void testFtpProxyHostAndPort() throws Exception { - System.setProperty("ftp.proxyHost", "httpproxy.com"); - System.setProperty("ftp.proxyPort", "8080"); - checkMapping("ftp://example.com/", "PROXY httpproxy.com:8080"); - checkMapping("http://example.com/", "DIRECT"); - checkMapping("https://example.com/", "DIRECT"); - } - - /** - * Test ftp.proxyHost and default port. - * - * @throws Exception - */ - @SmallTest - @Feature({"AndroidWebView"}) - public void testFtpProxyHostOnly() throws Exception { - System.setProperty("ftp.proxyHost", "httpproxy.com"); - checkMapping("ftp://example.com/", "PROXY httpproxy.com:80"); - checkMapping("http://example.com/", "DIRECT"); - checkMapping("https://example.com/", "DIRECT"); - } - - /** - * Test https.proxyHost and https.proxyPort works. - * - * @throws Exception - */ - @SmallTest - @Feature({"AndroidWebView"}) - public void testHttpsProxyHostAndPort() throws Exception { - System.setProperty("https.proxyHost", "httpproxy.com"); - System.setProperty("https.proxyPort", "8080"); - checkMapping("ftp://example.com/", "DIRECT"); - checkMapping("http://example.com/", "DIRECT"); - checkMapping("https://example.com/", "PROXY httpproxy.com:8080"); - } - - /** - * Default http proxy is used if a scheme-specific one is not found. - * - * @throws Exception - */ - @SmallTest - @Feature({"AndroidWebView"}) - public void testDefaultProxyExplictPort() throws Exception { - System.setProperty("ftp.proxyHost", "httpproxy.com"); - System.setProperty("ftp.proxyPort", "8080"); - System.setProperty("proxyHost", "defaultproxy.com"); - System.setProperty("proxyPort", "8080"); - checkMapping("ftp://example.com/", "PROXY httpproxy.com:8080"); - checkMapping("http://example.com/", "PROXY defaultproxy.com:8080"); - checkMapping("https://example.com/", "PROXY defaultproxy.com:8080"); - } - - /** - * SOCKS proxy is used if scheme-specific one is not found. - * - * @throws Exception - */ - @SmallTest - @Feature({"AndroidWebView"}) - public void testFallbackToSocks() throws Exception { - System.setProperty("http.proxyHost", "defaultproxy.com"); - System.setProperty("socksProxyHost", "socksproxy.com"); - checkMapping("ftp://example.com", "SOCKS5 socksproxy.com:1080"); - checkMapping("http://example.com/", "PROXY defaultproxy.com:80"); - checkMapping("https://example.com/", "SOCKS5 socksproxy.com:1080"); - } - - /** - * SOCKS proxy port is used if specified - * - * @throws Exception - */ - @SmallTest - @Feature({"AndroidWebView"}) - public void testSocksExplicitPort() throws Exception { - System.setProperty("socksProxyHost", "socksproxy.com"); - System.setProperty("socksProxyPort", "9000"); - checkMapping("http://example.com/", "SOCKS5 socksproxy.com:9000"); - } - - /** - * SOCKS proxy is ignored if default HTTP proxy defined. - * - * @throws Exception - */ - @SmallTest - @Feature({"AndroidWebView"}) - public void testHttpProxySupercedesSocks() throws Exception { - System.setProperty("proxyHost", "defaultproxy.com"); - System.setProperty("socksProxyHost", "socksproxy.com"); - System.setProperty("socksProxyPort", "9000"); - checkMapping("http://example.com/", "PROXY defaultproxy.com:80"); - } -} - diff --git a/chromium/net/android/javatests/src/org/chromium/net/NetErrorsTest.java b/chromium/net/android/javatests/src/org/chromium/net/NetErrorsTest.java deleted file mode 100644 index 57885b2efe0..00000000000 --- a/chromium/net/android/javatests/src/org/chromium/net/NetErrorsTest.java +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright (c) 2012 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -/** - * Tests to verify that NetError.java is created succesfully. - */ - -package org.chromium.net; - -import android.test.InstrumentationTestCase; -import android.test.suitebuilder.annotation.SmallTest; - -import org.chromium.base.test.util.Feature; - -public class NetErrorsTest extends InstrumentationTestCase { - // These are manually copied and should be kept in sync with net_error_list.h. - private static int IO_PENDING_ERROR = -1; - private static int FAILED_ERROR = -2; - - /** - * Test whether we can include NetError.java and call to static integers defined in the file. - * - * @throws Exception - */ - @SmallTest - @Feature({"Android-AppBase"}) - public void testExampleErrorDefined() throws Exception { - assertEquals(IO_PENDING_ERROR, NetError.ERR_IO_PENDING); - assertEquals(FAILED_ERROR, NetError.ERR_FAILED); - } -} diff --git a/chromium/net/android/javatests/src/org/chromium/net/NetworkChangeNotifierTest.java b/chromium/net/android/javatests/src/org/chromium/net/NetworkChangeNotifierTest.java deleted file mode 100644 index b52d184de67..00000000000 --- a/chromium/net/android/javatests/src/org/chromium/net/NetworkChangeNotifierTest.java +++ /dev/null @@ -1,134 +0,0 @@ -// Copyright (c) 2012 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -package org.chromium.net; - -import android.content.Context; -import android.content.Intent; -import android.net.ConnectivityManager; -import android.telephony.TelephonyManager; -import android.test.InstrumentationTestCase; -import android.test.UiThreadTest; -import android.test.suitebuilder.annotation.MediumTest; - -import org.chromium.base.ActivityStatus; -import org.chromium.base.test.util.Feature; - -public class NetworkChangeNotifierTest extends InstrumentationTestCase { - /** - * Listens for alerts fired by the NetworkChangeNotifier when network status changes. - */ - private static class NetworkChangeNotifierTestObserver - implements NetworkChangeNotifier.ConnectionTypeObserver { - private boolean mReceivedNotification = false; - - @Override - public void onConnectionTypeChanged(int connectionType) { - mReceivedNotification = true; - } - - public boolean hasReceivedNotification() { - return mReceivedNotification; - } - - public void resetHasReceivedNotification() { - mReceivedNotification = false; - } - } - - /** - * Mocks out calls to the ConnectivityManager. - */ - class MockConnectivityManagerDelegate - extends NetworkChangeNotifierAutoDetect.ConnectivityManagerDelegate { - private boolean mActiveNetworkExists; - private int mNetworkType; - private int mNetworkSubtype; - - @Override - boolean activeNetworkExists() { - return mActiveNetworkExists; - } - - @Override - boolean isConnected() { - return getNetworkType() != NetworkChangeNotifier.CONNECTION_NONE; - } - - void setActiveNetworkExists(boolean networkExists) { - mActiveNetworkExists = networkExists; - } - - @Override - int getNetworkType() { - return mNetworkType; - } - - void setNetworkType(int networkType) { - mNetworkType = networkType; - } - - @Override - int getNetworkSubtype() { - return mNetworkSubtype; - } - - void setNetworkSubtype(int networkSubtype) { - mNetworkSubtype = networkSubtype; - } - } - - /** - * Tests that when Chrome gets an intent indicating a change in network connectivity, it sends a - * notification to Java observers. - */ - @UiThreadTest - @MediumTest - @Feature({"Android-AppBase"}) - public void testNetworkChangeNotifierJavaObservers() throws InterruptedException { - // Create a new notifier that doesn't have a native-side counterpart. - Context context = getInstrumentation().getTargetContext(); - NetworkChangeNotifier.resetInstanceForTests(context); - - NetworkChangeNotifier.setAutoDetectConnectivityState(true); - NetworkChangeNotifierAutoDetect receiver = NetworkChangeNotifier.getAutoDetectorForTest(); - assertTrue(receiver != null); - - MockConnectivityManagerDelegate connectivityDelegate = - new MockConnectivityManagerDelegate(); - connectivityDelegate.setActiveNetworkExists(true); - connectivityDelegate.setNetworkType(NetworkChangeNotifier.CONNECTION_UNKNOWN); - connectivityDelegate.setNetworkSubtype(TelephonyManager.NETWORK_TYPE_UNKNOWN); - receiver.setConnectivityManagerDelegateForTests(connectivityDelegate); - - // Initialize the NetworkChangeNotifier with a connection. - Intent connectivityIntent = new Intent(ConnectivityManager.CONNECTIVITY_ACTION); - receiver.onReceive(getInstrumentation().getTargetContext(), connectivityIntent); - - // We shouldn't be re-notified if the connection hasn't actually changed. - NetworkChangeNotifierTestObserver observer = new NetworkChangeNotifierTestObserver(); - NetworkChangeNotifier.addConnectionTypeObserver(observer); - receiver.onReceive(getInstrumentation().getTargetContext(), connectivityIntent); - assertFalse(observer.hasReceivedNotification()); - - // Mimic that connectivity has been lost and ensure that Chrome notifies our observer. - connectivityDelegate.setActiveNetworkExists(false); - connectivityDelegate.setNetworkType(NetworkChangeNotifier.CONNECTION_NONE); - Intent noConnectivityIntent = new Intent(ConnectivityManager.CONNECTIVITY_ACTION); - receiver.onReceive(getInstrumentation().getTargetContext(), noConnectivityIntent); - assertTrue(observer.hasReceivedNotification()); - - observer.resetHasReceivedNotification(); - // Pretend we got moved to the background. - receiver.onActivityStateChange(ActivityStatus.PAUSED); - // Change the state. - connectivityDelegate.setActiveNetworkExists(true); - connectivityDelegate.setNetworkType(NetworkChangeNotifier.CONNECTION_WIFI); - // The NetworkChangeNotifierAutoDetect doesn't receive any notification while we are in the - // background, but when we get back to the foreground the state changed should be detected - // and a notification sent. - receiver.onActivityStateChange(ActivityStatus.RESUMED); - assertTrue(observer.hasReceivedNotification()); - } -} diff --git a/chromium/net/android/javatests/src/org/chromium/net/X509UtilTest.java b/chromium/net/android/javatests/src/org/chromium/net/X509UtilTest.java deleted file mode 100644 index 7dcbc685cd4..00000000000 --- a/chromium/net/android/javatests/src/org/chromium/net/X509UtilTest.java +++ /dev/null @@ -1,102 +0,0 @@ -// Copyright 2013 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -package org.chromium.net; - -import android.content.Context; -import android.content.Intent; -import android.net.ConnectivityManager; -import android.telephony.TelephonyManager; -import android.test.UiThreadTest; -import android.test.suitebuilder.annotation.MediumTest; -import android.test.InstrumentationTestCase; -import android.util.Base64; - -import java.io.BufferedReader; -import java.io.FileReader; -import java.io.IOException; -import java.io.RandomAccessFile; -import java.security.GeneralSecurityException; -import java.security.cert.CertificateException; -import java.security.cert.CertificateParsingException; -import java.security.cert.X509Certificate; -import java.util.Arrays; - -import org.chromium.base.PathUtils; - -/** - * Tests for org.chromium.net.X509Util. - */ -public class X509UtilTest extends InstrumentationTestCase { - private static final String CERTS_DIRECTORY = - PathUtils.getExternalStorageDirectory() + "/net/data/ssl/certificates/"; - private static final String BAD_EKU_TEST_ROOT = "eku-test-root.pem"; - private static final String CRITICAL_CODE_SIGNING_EE = "crit-codeSigning-chain.pem"; - private static final String NON_CRITICAL_CODE_SIGNING_EE = "non-crit-codeSigning-chain.pem"; - private static final String WEB_CLIENT_AUTH_EE = "invalid_key_usage_cert.der"; - private static final String OK_CERT = "ok_cert.pem"; - private static final String GOOD_ROOT_CA = "root_ca_cert.pem"; - - private static final String BEGIN_MARKER = "-----BEGIN CERTIFICATE-----"; - private static final String END_MARKER = "-----END CERTIFICATE-----"; - - private static byte[] pemToDer(String pemPathname) throws IOException { - BufferedReader reader = new BufferedReader(new FileReader(pemPathname)); - StringBuilder builder = new StringBuilder(); - - // Skip past leading junk lines, if any. - String line = reader.readLine(); - while (line != null && !line.contains(BEGIN_MARKER)) line = reader.readLine(); - - // Then skip the BEGIN_MARKER itself, if present. - while (line != null && line.contains(BEGIN_MARKER)) line = reader.readLine(); - - // Now gather the data lines into the builder. - while (line != null && !line.contains(END_MARKER)) { - builder.append(line.trim()); - line = reader.readLine(); - } - - reader.close(); - return Base64.decode(builder.toString(), Base64.DEFAULT); - } - - private static byte[] readFileBytes(String pathname) throws IOException { - RandomAccessFile file = new RandomAccessFile(pathname, "r"); - byte[] bytes = new byte[(int) file.length()]; - int bytesRead = file.read(bytes); - if (bytesRead != bytes.length) - return Arrays.copyOfRange(bytes, 0, bytesRead); - return bytes; - } - - @MediumTest - public void testEkusVerified() throws GeneralSecurityException, IOException { - X509Util.addTestRootCertificate(pemToDer(CERTS_DIRECTORY + BAD_EKU_TEST_ROOT)); - X509Util.addTestRootCertificate(pemToDer(CERTS_DIRECTORY + GOOD_ROOT_CA)); - - assertFalse(X509Util.verifyKeyUsage( - X509Util.createCertificateFromBytes( - pemToDer(CERTS_DIRECTORY + CRITICAL_CODE_SIGNING_EE)))); - - assertFalse(X509Util.verifyKeyUsage( - X509Util.createCertificateFromBytes( - pemToDer(CERTS_DIRECTORY + NON_CRITICAL_CODE_SIGNING_EE)))); - - assertFalse(X509Util.verifyKeyUsage( - X509Util.createCertificateFromBytes( - readFileBytes(CERTS_DIRECTORY + WEB_CLIENT_AUTH_EE)))); - - assertTrue(X509Util.verifyKeyUsage( - X509Util.createCertificateFromBytes( - pemToDer(CERTS_DIRECTORY + OK_CERT)))); - - try { - X509Util.clearTestRootCertificates(); - } catch (Exception e) { - fail("Could not clear test root certificates: " + e.toString()); - } - } -} - diff --git a/chromium/net/android/keystore_openssl.cc b/chromium/net/android/keystore_openssl.cc index cd55ece3336..5ad847344aa 100644 --- a/chromium/net/android/keystore_openssl.cc +++ b/chromium/net/android/keystore_openssl.cc @@ -35,7 +35,7 @@ // // Generally speaking, OpenSSL provides many different ways to sign // digests. This code doesn't support all these cases, only the ones that -// are required to sign the MAC during the OpenSSL handshake for TLS. +// are required to sign the digest during the OpenSSL handshake for TLS. // // The OpenSSL EVP_PKEY type is a generic wrapper around key pairs. // Internally, it can hold a pointer to a RSA, DSA or ECDSA structure, @@ -106,7 +106,6 @@ typedef crypto::ScopedOpenSSL<RSA, RSA_free> ScopedRSA; typedef crypto::ScopedOpenSSL<DSA, DSA_free> ScopedDSA; typedef crypto::ScopedOpenSSL<EC_KEY, EC_KEY_free> ScopedEC_KEY; typedef crypto::ScopedOpenSSL<EC_GROUP, EC_GROUP_free> ScopedEC_GROUP; -typedef crypto::ScopedOpenSSL<X509_SIG, X509_SIG_free> ScopedX509_SIG; // Custom RSA_METHOD that uses the platform APIs. // Note that for now, only signing through RSA_sign() is really supported. @@ -133,14 +132,60 @@ int RsaMethodPubDec(int flen, return -1; } +// See RSA_eay_private_encrypt in +// third_party/openssl/openssl/crypto/rsa/rsa_eay.c for the default +// implementation of this function. int RsaMethodPrivEnc(int flen, const unsigned char *from, unsigned char *to, RSA *rsa, int padding) { - NOTIMPLEMENTED(); - RSAerr(RSA_F_RSA_PRIVATE_ENCRYPT, RSA_R_RSA_OPERATIONS_NOT_SUPPORTED); - return -1; + DCHECK_EQ(RSA_PKCS1_PADDING, padding); + if (padding != RSA_PKCS1_PADDING) { + // TODO(davidben): If we need to, we can implement RSA_NO_PADDING + // by using javax.crypto.Cipher and picking either the + // "RSA/ECB/NoPadding" or "RSA/ECB/PKCS1Padding" transformation as + // appropriate. I believe support for both of these was added in + // the same Android version as the "NONEwithRSA" + // java.security.Signature algorithm, so the same version checks + // for GetRsaLegacyKey should work. + RSAerr(RSA_F_RSA_PRIVATE_ENCRYPT, RSA_R_UNKNOWN_PADDING_TYPE); + return -1; + } + + // Retrieve private key JNI reference. + jobject private_key = reinterpret_cast<jobject>(RSA_get_app_data(rsa)); + if (!private_key) { + LOG(WARNING) << "Null JNI reference passed to RsaMethodPrivEnc!"; + RSAerr(RSA_F_RSA_PRIVATE_ENCRYPT, ERR_R_INTERNAL_ERROR); + return -1; + } + + base::StringPiece from_piece(reinterpret_cast<const char*>(from), flen); + std::vector<uint8> result; + // For RSA keys, this function behaves as RSA_private_encrypt with + // PKCS#1 padding. + if (!RawSignDigestWithPrivateKey(private_key, from_piece, &result)) { + LOG(WARNING) << "Could not sign message in RsaMethodPrivEnc!"; + RSAerr(RSA_F_RSA_PRIVATE_ENCRYPT, ERR_R_INTERNAL_ERROR); + return -1; + } + + size_t expected_size = static_cast<size_t>(RSA_size(rsa)); + if (result.size() > expected_size) { + LOG(ERROR) << "RSA Signature size mismatch, actual: " + << result.size() << ", expected <= " << expected_size; + RSAerr(RSA_F_RSA_PRIVATE_ENCRYPT, ERR_R_INTERNAL_ERROR); + return -1; + } + + // Copy result to OpenSSL-provided buffer. RawSignDigestWithPrivateKey + // should pad with leading 0s, but if it doesn't, pad the result. + size_t zero_pad = expected_size - result.size(); + memset(to, 0, zero_pad); + memcpy(to + zero_pad, &result[0], result.size()); + + return expected_size; } int RsaMethodPrivDec(int flen, @@ -154,8 +199,6 @@ int RsaMethodPrivDec(int flen, } int RsaMethodInit(RSA* rsa) { - // Required to ensure that RsaMethodSign will be called. - rsa->flags |= RSA_FLAG_SIGN_VER; return 0; } @@ -173,99 +216,6 @@ int RsaMethodFinish(RSA* rsa) { return 0; } -// Although these parameters are, per OpenSSL, named |message| and -// |message_len|, RsaMethodSign is actually passed a message digest, -// not the original message. -int RsaMethodSign(int type, - const unsigned char* message, - unsigned int message_len, - unsigned char* signature, - unsigned int* signature_len, - const RSA* rsa) { - // Retrieve private key JNI reference. - jobject private_key = reinterpret_cast<jobject>(RSA_get_app_data(rsa)); - if (!private_key) { - LOG(WARNING) << "Null JNI reference passed to RsaMethodSign!"; - return 0; - } - - // See RSA_sign in third_party/openssl/openssl/crypto/rsa/rsa_sign.c. - base::StringPiece message_piece; - std::vector<uint8> buffer; // To store |message| wrapped in a DigestInfo. - if (type == NID_md5_sha1) { - // For TLS < 1.2, sign just |message|. - message_piece.set(message, static_cast<size_t>(message_len)); - } else { - // For TLS 1.2, wrap |message| in a PKCS #1 DigestInfo before signing. - ScopedX509_SIG sig(X509_SIG_new()); - if (!sig.get()) - return 0; - if (X509_ALGOR_set0(sig.get()->algor, - OBJ_nid2obj(type), V_ASN1_NULL, 0) != 1) { - return 0; - } - if (sig.get()->algor->algorithm == NULL) { - RSAerr(RSA_F_RSA_SIGN, RSA_R_UNKNOWN_ALGORITHM_TYPE); - return 0; - } - if (sig.get()->algor->algorithm->length == 0) { - RSAerr(RSA_F_RSA_SIGN, - RSA_R_THE_ASN1_OBJECT_IDENTIFIER_IS_NOT_KNOWN_FOR_THIS_MD); - return 0; - } - if (ASN1_OCTET_STRING_set(sig.get()->digest, message, message_len) != 1) - return 0; - - int len = i2d_X509_SIG(sig.get(), NULL); - if (len < 0) { - LOG(WARNING) << "Couldn't encode X509_SIG structure"; - return 0; - } - buffer.resize(len); - // OpenSSL takes a pointer to a pointer so it can kindly increment - // it for you. - unsigned char* p = &buffer[0]; - len = i2d_X509_SIG(sig.get(), &p); - if (len < 0) { - LOG(WARNING) << "Couldn't encode X509_SIG structure"; - return 0; - } - - message_piece.set(&buffer[0], static_cast<size_t>(len)); - } - - // Sanity-check the size. - // - // TODO(davidben): Do we need to do this? OpenSSL does, but - // RawSignDigestWithPrivateKey does error on sufficiently large - // input. However, it doesn't take the padding into account. - size_t expected_size = static_cast<size_t>(RSA_size(rsa)); - if (message_piece.size() > expected_size - RSA_PKCS1_PADDING_SIZE) { - RSAerr(RSA_F_RSA_SIGN, RSA_R_DIGEST_TOO_BIG_FOR_RSA_KEY); - return 0; - } - - // Sign |message_piece| with the private key through JNI. - std::vector<uint8> result; - - if (!RawSignDigestWithPrivateKey( - private_key, message_piece, &result)) { - LOG(WARNING) << "Could not sign message in RsaMethodSign!"; - return 0; - } - - if (result.size() > expected_size) { - LOG(ERROR) << "RSA Signature size mismatch, actual: " - << result.size() << ", expected <= " << expected_size; - return 0; - } - - // Copy result to OpenSSL-provided buffer - memcpy(signature, &result[0], result.size()); - *signature_len = static_cast<unsigned int>(result.size()); - return 1; -} - const RSA_METHOD android_rsa_method = { /* .name = */ "Android signing-only RSA method", /* .rsa_pub_enc = */ RsaMethodPubEnc, @@ -281,7 +231,7 @@ const RSA_METHOD android_rsa_method = { // it's not valid for the certificate. /* .flags = */ RSA_METHOD_FLAG_NO_CHECK, /* .app_data = */ NULL, - /* .rsa_sign = */ RsaMethodSign, + /* .rsa_sign = */ NULL, /* .rsa_verify = */ NULL, /* .rsa_keygen = */ NULL, }; diff --git a/chromium/net/android/net_jni_registrar.cc b/chromium/net/android/net_jni_registrar.cc deleted file mode 100644 index a6e09b65080..00000000000 --- a/chromium/net/android/net_jni_registrar.cc +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright (c) 2012 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "net/android/net_jni_registrar.h" - -#include "base/basictypes.h" -#include "base/android/jni_android.h" -#include "base/android/jni_registrar.h" -#include "net/android/gurl_utils.h" -#include "net/android/keystore.h" -#include "net/android/network_change_notifier_android.h" -#include "net/android/network_library.h" -#include "net/proxy/proxy_config_service_android.h" - -namespace net { -namespace android { - -static base::android::RegistrationMethod kNetRegisteredMethods[] = { - { "AndroidKeyStore", net::android::RegisterKeyStore }, - { "AndroidNetworkLibrary", net::android::RegisterNetworkLibrary }, - { "GURLUtils", net::RegisterGURLUtils }, - { "NetworkChangeNotifierAndroid", - net::NetworkChangeNotifierAndroid::Register }, - { "ProxyConfigService", net::ProxyConfigServiceAndroid::Register }, -}; - -bool RegisterJni(JNIEnv* env) { - return base::android::RegisterNativeMethods( - env, kNetRegisteredMethods, arraysize(kNetRegisteredMethods)); -} - -} // namespace android -} // namespace net diff --git a/chromium/net/android/net_jni_registrar.h b/chromium/net/android/net_jni_registrar.h deleted file mode 100644 index 2b45fb26d07..00000000000 --- a/chromium/net/android/net_jni_registrar.h +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright (c) 2012 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#ifndef NET_ANDROID_NET_JNI_REGISTRAR_H_ -#define NET_ANDROID_NET_JNI_REGISTRAR_H_ - -#include <jni.h> - -#include "net/base/net_export.h" - -namespace net { -namespace android { - -// Register all JNI bindings necessary for net. -NET_EXPORT bool RegisterJni(JNIEnv* env); - -} // namespace android -} // namespace net - -#endif // NET_ANDROID_NET_JNI_REGISTRAR_H_ diff --git a/chromium/net/base/address_family.h b/chromium/net/base/address_family.h index 75beb29d649..07adfaa2f4d 100644 --- a/chromium/net/base/address_family.h +++ b/chromium/net/base/address_family.h @@ -24,6 +24,8 @@ enum { HOST_RESOLVER_LOOPBACK_ONLY = 1 << 1, // Indicate the address family was set because no IPv6 support was detected. HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6 = 1 << 2, + // The resolver should only invoke getaddrinfo, not DnsClient. + HOST_RESOLVER_SYSTEM_ONLY = 1 << 3 }; typedef int HostResolverFlags; diff --git a/chromium/net/base/cache_type.h b/chromium/net/base/cache_type.h index 69b5646ece7..06aa826b8fb 100644 --- a/chromium/net/base/cache_type.h +++ b/chromium/net/base/cache_type.h @@ -13,7 +13,8 @@ enum CacheType { MEMORY_CACHE, // Data is stored only in memory. MEDIA_CACHE, // Optimized to handle media files. APP_CACHE, // Backing store for an AppCache. - SHADER_CACHE // Backing store for the GL shader cache. + SHADER_CACHE, // Backing store for the GL shader cache. + PNACL_CACHE, // Backing store the PNaCl translation cache }; // The types of disk cache backend, only used at backend instantiation. diff --git a/chromium/net/base/escape_unittest.cc b/chromium/net/base/escape_unittest.cc index bed49a5e1d5..e7e435c08ef 100644 --- a/chromium/net/base/escape_unittest.cc +++ b/chromium/net/base/escape_unittest.cc @@ -337,10 +337,12 @@ TEST(EscapeTest, UnescapeAndDecodeUTF8URLComponent) { TEST(EscapeTest, AdjustOffset) { const AdjustOffsetCase adjust_cases[] = { - {"", 0, std::string::npos}, + {"", 0, 0}, + {"", 1, std::string::npos}, {"test", 0, 0}, {"test", 2, 2}, - {"test", 4, std::string::npos}, + {"test", 4, 4}, + {"test", 5, std::string::npos}, {"test", std::string::npos, std::string::npos}, {"%2dtest", 6, 4}, {"%2dtest", 2, std::string::npos}, diff --git a/chromium/net/base/file_stream.cc b/chromium/net/base/file_stream.cc index 85adaece3cd..fd2eb4af3f9 100644 --- a/chromium/net/base/file_stream.cc +++ b/chromium/net/base/file_stream.cc @@ -93,6 +93,19 @@ int FileStream::OpenSync(const base::FilePath& path, int open_flags) { return context_->OpenSync(path, open_flags_); } +int FileStream::Close(const CompletionCallback& callback) { + DCHECK(is_async()); + context_->CloseAsync(callback); + return ERR_IO_PENDING; +} + +int FileStream::CloseSync() { + DCHECK(!is_async()); + base::ThreadRestrictions::AssertIOAllowed(); + context_->CloseSync(); + return OK; +} + bool FileStream::IsOpen() const { return context_->file() != base::kInvalidPlatformFileValue; } diff --git a/chromium/net/base/file_stream.h b/chromium/net/base/file_stream.h index 0fb3fb26569..9fe274759c7 100644 --- a/chromium/net/base/file_stream.h +++ b/chromium/net/base/file_stream.h @@ -81,6 +81,16 @@ class NET_EXPORT FileStream { // automatically closed when FileStream is destructed. virtual int OpenSync(const base::FilePath& path, int open_flags); + // Returns ERR_IO_PENDING and closes the file asynchronously, calling + // |callback| when done. + // It is invalid to request any asynchronous operations while there is an + // in-flight asynchronous operation. + virtual int Close(const CompletionCallback& callback); + + // Closes the file immediately and returns OK. If the file is open + // asynchronously, Close(const CompletionCallback&) should be used instead. + virtual int CloseSync(); + // Returns true if Open succeeded and Close has not been called. virtual bool IsOpen() const; diff --git a/chromium/net/base/file_stream_context.cc b/chromium/net/base/file_stream_context.cc index abc058a9ca8..2e774752045 100644 --- a/chromium/net/base/file_stream_context.cc +++ b/chromium/net/base/file_stream_context.cc @@ -105,6 +105,21 @@ void FileStream::Context::CloseSync() { } } +void FileStream::Context::CloseAsync(const CompletionCallback& callback) { + DCHECK(!async_in_progress_); + const bool posted = base::PostTaskAndReplyWithResult( + task_runner_.get(), + FROM_HERE, + base::Bind(&Context::CloseFileImpl, base::Unretained(this)), + base::Bind(&Context::ProcessAsyncResult, + base::Unretained(this), + IntToInt64(callback), + FILE_ERROR_SOURCE_CLOSE)); + DCHECK(posted); + + async_in_progress_ = true; +} + void FileStream::Context::SeekAsync(Whence whence, int64 offset, const Int64CompletionCallback& callback) { @@ -159,11 +174,6 @@ void FileStream::Context::RecordError(const IOResult& result, return; } - // The following check is against incorrect use or bug. File descriptor - // shouldn't ever be closed outside of FileStream while it still tries to do - // something with it. - DCHECK_NE(result.result, ERR_INVALID_HANDLE); - if (!orphaned_) { bound_net_log_.AddEvent( NetLog::TYPE_FILE_STREAM_ERROR, diff --git a/chromium/net/base/file_stream_context.h b/chromium/net/base/file_stream_context.h index 15c25bb8d4a..3ce8a5b6cdc 100644 --- a/chromium/net/base/file_stream_context.h +++ b/chromium/net/base/file_stream_context.h @@ -109,6 +109,8 @@ class FileStream::Context { void CloseSync(); + void CloseAsync(const CompletionCallback& callback); + void SeekAsync(Whence whence, int64 offset, const Int64CompletionCallback& callback); @@ -191,6 +193,9 @@ class FileStream::Context { // Flushes all data written to the stream. IOResult FlushFileImpl(); + // Closes the file. + IOResult CloseFileImpl(); + #if defined(OS_WIN) void IOCompletionIsPending(const CompletionCallback& callback, IOBuffer* buf); diff --git a/chromium/net/base/file_stream_context_posix.cc b/chromium/net/base/file_stream_context_posix.cc index 1ef5be5fc16..6e6bc6eaa69 100644 --- a/chromium/net/base/file_stream_context_posix.cc +++ b/chromium/net/base/file_stream_context_posix.cc @@ -184,4 +184,13 @@ FileStream::Context::IOResult FileStream::Context::WriteFileImpl( return IOResult(res, 0); } +FileStream::Context::IOResult FileStream::Context::CloseFileImpl() { + bool success = base::ClosePlatformFile(file_); + file_ = base::kInvalidPlatformFileValue; + if (!success) + return IOResult::FromOSError(errno); + + return IOResult(OK, 0); +} + } // namespace net diff --git a/chromium/net/base/file_stream_context_win.cc b/chromium/net/base/file_stream_context_win.cc index 5fb30859e75..4a666c70a5b 100644 --- a/chromium/net/base/file_stream_context_win.cc +++ b/chromium/net/base/file_stream_context_win.cc @@ -196,6 +196,15 @@ FileStream::Context::IOResult FileStream::Context::FlushFileImpl() { return IOResult::FromOSError(GetLastError()); } +FileStream::Context::IOResult FileStream::Context::CloseFileImpl() { + bool success = base::ClosePlatformFile(file_); + file_ = base::kInvalidPlatformFileValue; + if (success) + return IOResult(OK, 0); + + return IOResult::FromOSError(GetLastError()); +} + void FileStream::Context::IOCompletionIsPending( const CompletionCallback& callback, IOBuffer* buf) { diff --git a/chromium/net/base/file_stream_metrics.cc b/chromium/net/base/file_stream_metrics.cc index 4dc0576325f..bff7174a40b 100644 --- a/chromium/net/base/file_stream_metrics.cc +++ b/chromium/net/base/file_stream_metrics.cc @@ -19,7 +19,8 @@ const char* FileErrorSourceStrings[] = { "SEEK", "FLUSH", "SET_EOF", - "GET_SIZE" + "GET_SIZE", + "CLOSE" }; COMPILE_ASSERT(ARRAYSIZE_UNSAFE(FileErrorSourceStrings) == @@ -34,11 +35,6 @@ void RecordFileErrorTypeCount(FileErrorSource source) { } // namespace void RecordFileError(int error, FileErrorSource source, bool record) { - LOG(ERROR) << " " << __FUNCTION__ << "()" - << " error = " << error - << " source = " << source - << " record = " << record; - if (!record) return; @@ -88,6 +84,12 @@ void RecordFileError(int error, FileErrorSource source, bool record) { max_bucket); break; + case FILE_ERROR_SOURCE_CLOSE: + UMA_HISTOGRAM_ENUMERATION("Net.FileError_Close", error, max_error); + UMA_HISTOGRAM_ENUMERATION("Net.FileErrorRange_Close", bucket, + max_bucket); + break; + default: break; } diff --git a/chromium/net/base/file_stream_metrics.h b/chromium/net/base/file_stream_metrics.h index 14988aad82e..4dab3ddc64b 100644 --- a/chromium/net/base/file_stream_metrics.h +++ b/chromium/net/base/file_stream_metrics.h @@ -17,6 +17,7 @@ enum FileErrorSource { FILE_ERROR_SOURCE_FLUSH, FILE_ERROR_SOURCE_SET_EOF, FILE_ERROR_SOURCE_GET_SIZE, + FILE_ERROR_SOURCE_CLOSE, FILE_ERROR_SOURCE_COUNT, }; diff --git a/chromium/net/base/file_stream_unittest.cc b/chromium/net/base/file_stream_unittest.cc index 4be58b738e8..c76f3d939ac 100644 --- a/chromium/net/base/file_stream_unittest.cc +++ b/chromium/net/base/file_stream_unittest.cc @@ -8,8 +8,10 @@ #include "base/callback.h" #include "base/file_util.h" #include "base/message_loop/message_loop.h" +#include "base/message_loop/message_loop_proxy.h" #include "base/path_service.h" #include "base/platform_file.h" +#include "base/run_loop.h" #include "base/synchronization/waitable_event.h" #include "base/test/test_timeouts.h" #include "net/base/capturing_net_log.h" @@ -46,6 +48,9 @@ class FileStreamTest : public PlatformTest { virtual void TearDown() { EXPECT_TRUE(base::DeleteFile(temp_file_path_, false)); + // FileStreamContexts must be asynchronously closed on the file task runner + // before they can be deleted. Pump the RunLoop to avoid leaks. + base::RunLoop().RunUntilIdle(); PlatformTest::TearDown(); } @@ -60,7 +65,7 @@ namespace { TEST_F(FileStreamTest, BasicOpenClose) { base::PlatformFile file = base::kInvalidPlatformFileValue; { - FileStream stream(NULL); + FileStream stream(NULL, base::MessageLoopProxy::current()); int rv = stream.OpenSync(temp_file_path(), base::PLATFORM_FILE_OPEN | base::PLATFORM_FILE_READ); EXPECT_EQ(OK, rv); @@ -73,6 +78,66 @@ TEST_F(FileStreamTest, BasicOpenClose) { EXPECT_FALSE(base::GetPlatformFileInfo(file, &info)); } +TEST_F(FileStreamTest, BasicOpenExplicitClose) { + base::PlatformFile file = base::kInvalidPlatformFileValue; + FileStream stream(NULL); + int rv = stream.OpenSync(temp_file_path(), + base::PLATFORM_FILE_OPEN | base::PLATFORM_FILE_READ); + EXPECT_EQ(OK, rv); + EXPECT_TRUE(stream.IsOpen()); + file = stream.GetPlatformFileForTesting(); + EXPECT_NE(base::kInvalidPlatformFileValue, file); + EXPECT_EQ(OK, stream.CloseSync()); + EXPECT_FALSE(stream.IsOpen()); + base::PlatformFileInfo info; + // The file should be closed. + EXPECT_FALSE(base::GetPlatformFileInfo(file, &info)); +} + +TEST_F(FileStreamTest, AsyncOpenExplicitClose) { + base::PlatformFile file = base::kInvalidPlatformFileValue; + TestCompletionCallback callback; + FileStream stream(NULL); + int flags = base::PLATFORM_FILE_OPEN | + base::PLATFORM_FILE_READ | + base::PLATFORM_FILE_ASYNC; + int rv = stream.Open(temp_file_path(), flags, callback.callback()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_EQ(OK, callback.WaitForResult()); + EXPECT_TRUE(stream.IsOpen()); + file = stream.GetPlatformFileForTesting(); + EXPECT_EQ(ERR_IO_PENDING, stream.Close(callback.callback())); + EXPECT_EQ(OK, callback.WaitForResult()); + EXPECT_FALSE(stream.IsOpen()); + base::PlatformFileInfo info; + // The file should be closed. + EXPECT_FALSE(base::GetPlatformFileInfo(file, &info)); +} + +TEST_F(FileStreamTest, AsyncOpenExplicitCloseOrphaned) { + base::PlatformFile file = base::kInvalidPlatformFileValue; + TestCompletionCallback callback; + base::PlatformFileInfo info; + scoped_ptr<FileStream> stream(new FileStream( + NULL, base::MessageLoopProxy::current())); + int flags = base::PLATFORM_FILE_OPEN | + base::PLATFORM_FILE_READ | + base::PLATFORM_FILE_ASYNC; + int rv = stream->Open(temp_file_path(), flags, callback.callback()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_EQ(OK, callback.WaitForResult()); + EXPECT_TRUE(stream->IsOpen()); + file = stream->GetPlatformFileForTesting(); + EXPECT_EQ(ERR_IO_PENDING, stream->Close(callback.callback())); + stream.reset(); + // File isn't actually closed yet. + EXPECT_TRUE(base::GetPlatformFileInfo(file, &info)); + base::RunLoop runloop; + runloop.RunUntilIdle(); + // The file should now be closed, though the callback has not been called. + EXPECT_FALSE(base::GetPlatformFileInfo(file, &info)); +} + TEST_F(FileStreamTest, FileHandleNotLeftOpen) { bool created = false; ASSERT_EQ(kTestDataSize, @@ -83,7 +148,8 @@ TEST_F(FileStreamTest, FileHandleNotLeftOpen) { { // Seek to the beginning of the file and read. - FileStream read_stream(file, flags, NULL); + FileStream read_stream(file, flags, NULL, + base::MessageLoopProxy::current()); EXPECT_TRUE(read_stream.IsOpen()); } @@ -105,7 +171,8 @@ TEST_F(FileStreamTest, UseFileHandle) { temp_file_path(), flags, &created, NULL); // Seek to the beginning of the file and read. - scoped_ptr<FileStream> read_stream(new FileStream(file, flags, NULL)); + scoped_ptr<FileStream> read_stream( + new FileStream(file, flags, NULL, base::MessageLoopProxy::current())); ASSERT_EQ(0, read_stream->SeekSync(FROM_BEGIN, 0)); ASSERT_EQ(kTestDataSize, read_stream->Available()); // Read into buffer and compare. @@ -120,7 +187,8 @@ TEST_F(FileStreamTest, UseFileHandle) { flags = base::PLATFORM_FILE_OPEN_ALWAYS | base::PLATFORM_FILE_WRITE; file = base::CreatePlatformFile(temp_file_path(), flags, &created, NULL); - scoped_ptr<FileStream> write_stream(new FileStream(file, flags, NULL)); + scoped_ptr<FileStream> write_stream( + new FileStream(file, flags, NULL, base::MessageLoopProxy::current())); ASSERT_EQ(0, write_stream->SeekSync(FROM_BEGIN, 0)); ASSERT_EQ(kTestDataSize, write_stream->WriteSync(kTestData, kTestDataSize)); @@ -133,7 +201,7 @@ TEST_F(FileStreamTest, UseFileHandle) { } TEST_F(FileStreamTest, UseClosedStream) { - FileStream stream(NULL); + FileStream stream(NULL, base::MessageLoopProxy::current()); EXPECT_FALSE(stream.IsOpen()); @@ -156,7 +224,7 @@ TEST_F(FileStreamTest, BasicRead) { bool ok = file_util::GetFileSize(temp_file_path(), &file_size); EXPECT_TRUE(ok); - FileStream stream(NULL); + FileStream stream(NULL, base::MessageLoopProxy::current()); int flags = base::PLATFORM_FILE_OPEN | base::PLATFORM_FILE_READ; int rv = stream.OpenSync(temp_file_path(), flags); @@ -186,7 +254,7 @@ TEST_F(FileStreamTest, AsyncRead) { bool ok = file_util::GetFileSize(temp_file_path(), &file_size); EXPECT_TRUE(ok); - FileStream stream(NULL); + FileStream stream(NULL, base::MessageLoopProxy::current()); int flags = base::PLATFORM_FILE_OPEN | base::PLATFORM_FILE_READ | base::PLATFORM_FILE_ASYNC; @@ -221,7 +289,8 @@ TEST_F(FileStreamTest, AsyncRead_EarlyDelete) { bool ok = file_util::GetFileSize(temp_file_path(), &file_size); EXPECT_TRUE(ok); - scoped_ptr<FileStream> stream(new FileStream(NULL)); + scoped_ptr<FileStream> stream( + new FileStream(NULL, base::MessageLoopProxy::current())); int flags = base::PLATFORM_FILE_OPEN | base::PLATFORM_FILE_READ | base::PLATFORM_FILE_ASYNC; @@ -239,7 +308,7 @@ TEST_F(FileStreamTest, AsyncRead_EarlyDelete) { if (rv < 0) { EXPECT_EQ(ERR_IO_PENDING, rv); // The callback should not be called if the request is cancelled. - base::MessageLoop::current()->RunUntilIdle(); + base::RunLoop().RunUntilIdle(); EXPECT_FALSE(callback.have_result()); } else { EXPECT_EQ(std::string(kTestData, rv), std::string(buf->data(), rv)); @@ -251,7 +320,7 @@ TEST_F(FileStreamTest, BasicRead_FromOffset) { bool ok = file_util::GetFileSize(temp_file_path(), &file_size); EXPECT_TRUE(ok); - FileStream stream(NULL); + FileStream stream(NULL, base::MessageLoopProxy::current()); int flags = base::PLATFORM_FILE_OPEN | base::PLATFORM_FILE_READ; int rv = stream.OpenSync(temp_file_path(), flags); @@ -286,7 +355,7 @@ TEST_F(FileStreamTest, AsyncRead_FromOffset) { bool ok = file_util::GetFileSize(temp_file_path(), &file_size); EXPECT_TRUE(ok); - FileStream stream(NULL); + FileStream stream(NULL, base::MessageLoopProxy::current()); int flags = base::PLATFORM_FILE_OPEN | base::PLATFORM_FILE_READ | base::PLATFORM_FILE_ASYNC; @@ -324,7 +393,7 @@ TEST_F(FileStreamTest, AsyncRead_FromOffset) { } TEST_F(FileStreamTest, SeekAround) { - FileStream stream(NULL); + FileStream stream(NULL, base::MessageLoopProxy::current()); int flags = base::PLATFORM_FILE_OPEN | base::PLATFORM_FILE_READ; int rv = stream.OpenSync(temp_file_path(), flags); @@ -347,7 +416,7 @@ TEST_F(FileStreamTest, SeekAround) { } TEST_F(FileStreamTest, AsyncSeekAround) { - FileStream stream(NULL); + FileStream stream(NULL, base::MessageLoopProxy::current()); int flags = base::PLATFORM_FILE_OPEN | base::PLATFORM_FILE_ASYNC | base::PLATFORM_FILE_READ; @@ -383,7 +452,8 @@ TEST_F(FileStreamTest, AsyncSeekAround) { } TEST_F(FileStreamTest, BasicWrite) { - scoped_ptr<FileStream> stream(new FileStream(NULL)); + scoped_ptr<FileStream> stream( + new FileStream(NULL, base::MessageLoopProxy::current())); int flags = base::PLATFORM_FILE_CREATE_ALWAYS | base::PLATFORM_FILE_WRITE; int rv = stream->OpenSync(temp_file_path(), flags); @@ -404,7 +474,7 @@ TEST_F(FileStreamTest, BasicWrite) { } TEST_F(FileStreamTest, AsyncWrite) { - FileStream stream(NULL); + FileStream stream(NULL, base::MessageLoopProxy::current()); int flags = base::PLATFORM_FILE_CREATE_ALWAYS | base::PLATFORM_FILE_WRITE | base::PLATFORM_FILE_ASYNC; @@ -440,7 +510,8 @@ TEST_F(FileStreamTest, AsyncWrite) { } TEST_F(FileStreamTest, AsyncWrite_EarlyDelete) { - scoped_ptr<FileStream> stream(new FileStream(NULL)); + scoped_ptr<FileStream> stream( + new FileStream(NULL, base::MessageLoopProxy::current())); int flags = base::PLATFORM_FILE_CREATE_ALWAYS | base::PLATFORM_FILE_WRITE | base::PLATFORM_FILE_ASYNC; @@ -460,7 +531,7 @@ TEST_F(FileStreamTest, AsyncWrite_EarlyDelete) { if (rv < 0) { EXPECT_EQ(ERR_IO_PENDING, rv); // The callback should not be called if the request is cancelled. - base::MessageLoop::current()->RunUntilIdle(); + base::RunLoop().RunUntilIdle(); EXPECT_FALSE(callback.have_result()); } else { ok = file_util::GetFileSize(temp_file_path(), &file_size); @@ -470,7 +541,8 @@ TEST_F(FileStreamTest, AsyncWrite_EarlyDelete) { } TEST_F(FileStreamTest, BasicWrite_FromOffset) { - scoped_ptr<FileStream> stream(new FileStream(NULL)); + scoped_ptr<FileStream> stream( + new FileStream(NULL, base::MessageLoopProxy::current())); int flags = base::PLATFORM_FILE_OPEN | base::PLATFORM_FILE_WRITE; int rv = stream->OpenSync(temp_file_path(), flags); @@ -499,7 +571,7 @@ TEST_F(FileStreamTest, AsyncWrite_FromOffset) { bool ok = file_util::GetFileSize(temp_file_path(), &file_size); EXPECT_TRUE(ok); - FileStream stream(NULL); + FileStream stream(NULL, base::MessageLoopProxy::current()); int flags = base::PLATFORM_FILE_OPEN | base::PLATFORM_FILE_WRITE | base::PLATFORM_FILE_ASYNC; @@ -541,7 +613,8 @@ TEST_F(FileStreamTest, BasicReadWrite) { bool ok = file_util::GetFileSize(temp_file_path(), &file_size); EXPECT_TRUE(ok); - scoped_ptr<FileStream> stream(new FileStream(NULL)); + scoped_ptr<FileStream> stream( + new FileStream(NULL, base::MessageLoopProxy::current())); int flags = base::PLATFORM_FILE_OPEN | base::PLATFORM_FILE_READ | base::PLATFORM_FILE_WRITE; @@ -580,7 +653,8 @@ TEST_F(FileStreamTest, BasicWriteRead) { bool ok = file_util::GetFileSize(temp_file_path(), &file_size); EXPECT_TRUE(ok); - scoped_ptr<FileStream> stream(new FileStream(NULL)); + scoped_ptr<FileStream> stream( + new FileStream(NULL, base::MessageLoopProxy::current())); int flags = base::PLATFORM_FILE_OPEN | base::PLATFORM_FILE_READ | base::PLATFORM_FILE_WRITE; @@ -628,7 +702,8 @@ TEST_F(FileStreamTest, BasicAsyncReadWrite) { bool ok = file_util::GetFileSize(temp_file_path(), &file_size); EXPECT_TRUE(ok); - scoped_ptr<FileStream> stream(new FileStream(NULL)); + scoped_ptr<FileStream> stream( + new FileStream(NULL, base::MessageLoopProxy::current())); int flags = base::PLATFORM_FILE_OPEN | base::PLATFORM_FILE_READ | base::PLATFORM_FILE_WRITE | @@ -687,7 +762,8 @@ TEST_F(FileStreamTest, BasicAsyncWriteRead) { bool ok = file_util::GetFileSize(temp_file_path(), &file_size); EXPECT_TRUE(ok); - scoped_ptr<FileStream> stream(new FileStream(NULL)); + scoped_ptr<FileStream> stream( + new FileStream(NULL, base::MessageLoopProxy::current())); int flags = base::PLATFORM_FILE_OPEN | base::PLATFORM_FILE_READ | base::PLATFORM_FILE_WRITE | @@ -778,7 +854,7 @@ class TestWriteReadCompletionCallback { DCHECK(!waiting_for_result_); while (!have_result_) { waiting_for_result_ = true; - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); waiting_for_result_ = false; } have_result_ = false; // auto-reset for next callback @@ -853,7 +929,8 @@ TEST_F(FileStreamTest, AsyncWriteRead) { bool ok = file_util::GetFileSize(temp_file_path(), &file_size); EXPECT_TRUE(ok); - scoped_ptr<FileStream> stream(new FileStream(NULL)); + scoped_ptr<FileStream> stream( + new FileStream(NULL, base::MessageLoopProxy::current())); int flags = base::PLATFORM_FILE_OPEN | base::PLATFORM_FILE_READ | base::PLATFORM_FILE_WRITE | @@ -911,7 +988,7 @@ class TestWriteCloseCompletionCallback { DCHECK(!waiting_for_result_); while (!have_result_) { waiting_for_result_ = true; - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); waiting_for_result_ = false; } have_result_ = false; // auto-reset for next callback @@ -962,7 +1039,8 @@ TEST_F(FileStreamTest, AsyncWriteClose) { bool ok = file_util::GetFileSize(temp_file_path(), &file_size); EXPECT_TRUE(ok); - scoped_ptr<FileStream> stream(new FileStream(NULL)); + scoped_ptr<FileStream> stream( + new FileStream(NULL, base::MessageLoopProxy::current())); int flags = base::PLATFORM_FILE_OPEN | base::PLATFORM_FILE_READ | base::PLATFORM_FILE_WRITE | @@ -999,7 +1077,8 @@ TEST_F(FileStreamTest, AsyncWriteClose) { TEST_F(FileStreamTest, Truncate) { int flags = base::PLATFORM_FILE_CREATE_ALWAYS | base::PLATFORM_FILE_WRITE; - scoped_ptr<FileStream> write_stream(new FileStream(NULL)); + scoped_ptr<FileStream> write_stream( + new FileStream(NULL, base::MessageLoopProxy::current())); ASSERT_EQ(OK, write_stream->OpenSync(temp_file_path(), flags)); // Write some data to the file. @@ -1017,13 +1096,14 @@ TEST_F(FileStreamTest, Truncate) { // Read in the contents and make sure we get back what we expected. std::string read_contents; - EXPECT_TRUE(file_util::ReadFileToString(temp_file_path(), &read_contents)); + EXPECT_TRUE(base::ReadFileToString(temp_file_path(), &read_contents)); EXPECT_EQ("01230123", read_contents); } TEST_F(FileStreamTest, AsyncOpenAndDelete) { - scoped_ptr<FileStream> stream(new FileStream(NULL)); + scoped_ptr<FileStream> stream( + new FileStream(NULL, base::MessageLoopProxy::current())); int flags = base::PLATFORM_FILE_OPEN | base::PLATFORM_FILE_WRITE | base::PLATFORM_FILE_ASYNC; @@ -1035,46 +1115,62 @@ TEST_F(FileStreamTest, AsyncOpenAndDelete) { // complete. Should be safe. stream.reset(); // open_callback won't be called. - base::MessageLoop::current()->RunUntilIdle(); + base::RunLoop().RunUntilIdle(); EXPECT_FALSE(open_callback.have_result()); } // Verify that async Write() errors are mapped correctly. TEST_F(FileStreamTest, AsyncWriteError) { - scoped_ptr<FileStream> stream(new FileStream(NULL)); - int flags = base::PLATFORM_FILE_CREATE_ALWAYS | - base::PLATFORM_FILE_WRITE | + // Try opening file as read-only and then writing to it using FileStream. + base::PlatformFile file = base::CreatePlatformFile( + temp_file_path(), + base::PLATFORM_FILE_OPEN | base::PLATFORM_FILE_READ | + base::PLATFORM_FILE_ASYNC, + NULL, + NULL); + ASSERT_NE(base::kInvalidPlatformFileValue, file); + + int flags = base::PLATFORM_FILE_CREATE_ALWAYS | base::PLATFORM_FILE_WRITE | base::PLATFORM_FILE_ASYNC; - TestCompletionCallback callback; - int rv = stream->Open(temp_file_path(), flags, callback.callback()); - EXPECT_EQ(ERR_IO_PENDING, rv); - EXPECT_EQ(OK, callback.WaitForResult()); + scoped_ptr<FileStream> stream( + new FileStream(file, flags, NULL, base::MessageLoopProxy::current())); - // Try passing NULL buffer to Write() and check that it fails. - scoped_refptr<IOBuffer> buf = new WrappedIOBuffer(NULL); - rv = stream->Write(buf.get(), 1, callback.callback()); + scoped_refptr<IOBuffer> buf = new IOBuffer(1); + buf->data()[0] = 0; + + TestCompletionCallback callback; + int rv = stream->Write(buf.get(), 1, callback.callback()); if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); EXPECT_LT(rv, 0); + + base::ClosePlatformFile(file); } // Verify that async Read() errors are mapped correctly. TEST_F(FileStreamTest, AsyncReadError) { - scoped_ptr<FileStream> stream(new FileStream(NULL)); - int flags = base::PLATFORM_FILE_OPEN | - base::PLATFORM_FILE_READ | + // Try opening file for write and then reading from it using FileStream. + base::PlatformFile file = base::CreatePlatformFile( + temp_file_path(), + base::PLATFORM_FILE_OPEN | base::PLATFORM_FILE_WRITE | + base::PLATFORM_FILE_ASYNC, + NULL, + NULL); + ASSERT_NE(base::kInvalidPlatformFileValue, file); + + int flags = base::PLATFORM_FILE_OPEN | base::PLATFORM_FILE_READ | base::PLATFORM_FILE_ASYNC; - TestCompletionCallback callback; - int rv = stream->Open(temp_file_path(), flags, callback.callback()); - EXPECT_EQ(ERR_IO_PENDING, rv); - EXPECT_EQ(OK, callback.WaitForResult()); + scoped_ptr<FileStream> stream( + new FileStream(file, flags, NULL, base::MessageLoopProxy::current())); - // Try passing NULL buffer to Read() and check that it fails. - scoped_refptr<IOBuffer> buf = new WrappedIOBuffer(NULL); - rv = stream->Read(buf.get(), 1, callback.callback()); + scoped_refptr<IOBuffer> buf = new IOBuffer(1); + TestCompletionCallback callback; + int rv = stream->Read(buf.get(), 1, callback.callback()); if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); EXPECT_LT(rv, 0); + + base::ClosePlatformFile(file); } } // namespace diff --git a/chromium/net/base/gzip_filter_unittest.cc b/chromium/net/base/gzip_filter_unittest.cc index be98bae00d8..2ac4a3fc13b 100644 --- a/chromium/net/base/gzip_filter_unittest.cc +++ b/chromium/net/base/gzip_filter_unittest.cc @@ -69,7 +69,7 @@ class GZipUnitTest : public PlatformTest { file_path = file_path.AppendASCII("google.txt"); // Read data from the file into buffer. - ASSERT_TRUE(file_util::ReadFileToString(file_path, &source_buffer_)); + ASSERT_TRUE(base::ReadFileToString(file_path, &source_buffer_)); // Encode the data with deflate deflate_encode_buffer_ = new char[kDefaultBufferSize]; diff --git a/chromium/net/base/linked_hash_map.h b/chromium/net/base/linked_hash_map.h index d08b68d2592..7948647df05 100644 --- a/chromium/net/base/linked_hash_map.h +++ b/chromium/net/base/linked_hash_map.h @@ -146,7 +146,7 @@ class linked_hash_map { std::pair<typename MapType::iterator, typename MapType::iterator> eq_range = map_.equal_range(key); - return make_pair(eq_range.first->second, eq_range.second->second); + return std::make_pair(eq_range.first->second, eq_range.second->second); } std::pair<const_iterator, const_iterator> equal_range( @@ -159,13 +159,13 @@ class linked_hash_map { const const_iterator& end_iter = eq_range.second != map_.end() ? eq_range.second->second : end(); - return make_pair(start_iter, end_iter); + return std::make_pair(start_iter, end_iter); } // Returns the value mapped to key, or an inserted iterator to that position // in the map. Value& operator[](const key_type& key) { - return (*((this->insert(make_pair(key, Value()))).first)).second; + return (*((this->insert(std::make_pair(key, Value()))).first)).second; } // Inserts an element into the map @@ -174,7 +174,7 @@ class linked_hash_map { // return a pair with an iterator to it, and false indicating that we // didn't insert anything. typename MapType::iterator found = map_.find(pair.first); - if (found != map_.end()) return make_pair(found->second, false); + if (found != map_.end()) return std::make_pair(found->second, false); // Otherwise, insert into the list first. list_.push_back(pair); @@ -184,10 +184,10 @@ class linked_hash_map { typename ListType::iterator last = list_.end(); --last; - CHECK(map_.insert(make_pair(pair.first, last)).second) + CHECK(map_.insert(std::make_pair(pair.first, last)).second) << "Map and list are inconsistent"; - return make_pair(last, true); + return std::make_pair(last, true); } size_type size() const { diff --git a/chromium/net/base/load_flags_list.h b/chromium/net/base/load_flags_list.h index 968c122496d..2d004886d1c 100644 --- a/chromium/net/base/load_flags_list.h +++ b/chromium/net/base/load_flags_list.h @@ -119,3 +119,8 @@ LOAD_FLAG(DO_NOT_PROMPT_FOR_LOGIN, 1 << 26) // explicit user action. This can be used as a hint to treat the // request with higher priority. LOAD_FLAG(MAYBE_USER_GESTURE, 1 << 27) + +// Indicates that the username:password portion of the URL should not +// be honored, but that other forms of authority may be used. +LOAD_FLAG(DO_NOT_USE_EMBEDDED_IDENTITY, 1 << 28) + diff --git a/chromium/net/base/mime_sniffer.cc b/chromium/net/base/mime_sniffer.cc index cc83c7c3bda..ef2e27030bb 100644 --- a/chromium/net/base/mime_sniffer.cc +++ b/chromium/net/base/mime_sniffer.cc @@ -236,7 +236,7 @@ static const MagicNumber kExtraMagicNumbers[] = { MAGIC_NUMBER("video/3gpp", "....ftyp3g") MAGIC_NUMBER("video/3gpp", "....ftypavcl") MAGIC_NUMBER("video/mp4", "....ftyp") - MAGIC_NUMBER("video/quicktime", "MOVI") + MAGIC_NUMBER("video/quicktime", "....moov") MAGIC_NUMBER("application/x-shockwave-flash", "CWS") MAGIC_NUMBER("application/x-shockwave-flash", "FWS") MAGIC_NUMBER("video/x-flv", "FLV") @@ -804,8 +804,7 @@ bool ShouldSniffMimeType(const GURL& url, const std::string& mime_type) { UMASnifferHistogramGet("mime_sniffer.ShouldSniffMimeType2", 3); } bool sniffable_scheme = url.is_empty() || - url.SchemeIs("http") || - url.SchemeIs("https") || + url.SchemeIsHTTPOrHTTPS() || url.SchemeIs("ftp") || #if defined(OS_ANDROID) url.SchemeIs("content") || diff --git a/chromium/net/base/mime_util.cc b/chromium/net/base/mime_util.cc index 3a0310c53f3..70dd3530df2 100644 --- a/chromium/net/base/mime_util.cc +++ b/chromium/net/base/mime_util.cc @@ -10,6 +10,7 @@ #include "base/containers/hash_tables.h" #include "base/lazy_instance.h" #include "base/logging.h" +#include "base/stl_util.h" #include "base/strings/string_split.h" #include "base/strings/string_util.h" #include "base/strings/utf_string_conversions.h" @@ -222,10 +223,9 @@ bool MimeUtil::GetMimeTypeFromExtensionHelper( base::FilePath path_ext(ext); const string ext_narrow_str = path_ext.AsUTF8Unsafe(); - const char* mime_type; - - mime_type = FindMimeType(primary_mappings, arraysize(primary_mappings), - ext_narrow_str.c_str()); + const char* mime_type = FindMimeType(primary_mappings, + arraysize(primary_mappings), + ext_narrow_str.c_str()); if (mime_type) { *result = mime_type; return true; @@ -454,7 +454,7 @@ void MimeUtil::InitializeMimeTypeMaps() { non_image_map_.insert(supported_javascript_types[i]); for (size_t i = 0; i < arraysize(common_media_types); ++i) non_image_map_.insert(common_media_types[i]); -#if defined(GOOGLE_CHROME_BUILD) || defined(USE_PROPRIETARY_CODECS) +#if defined(USE_PROPRIETARY_CODECS) for (size_t i = 0; i < arraysize(proprietary_media_types); ++i) non_image_map_.insert(proprietary_media_types[i]); #endif @@ -462,7 +462,7 @@ void MimeUtil::InitializeMimeTypeMaps() { // Initialize the supported media types. for (size_t i = 0; i < arraysize(common_media_types); ++i) media_map_.insert(common_media_types[i]); -#if defined(GOOGLE_CHROME_BUILD) || defined(USE_PROPRIETARY_CODECS) +#if defined(USE_PROPRIETARY_CODECS) for (size_t i = 0; i < arraysize(proprietary_media_types); ++i) media_map_.insert(proprietary_media_types[i]); #endif @@ -472,7 +472,7 @@ void MimeUtil::InitializeMimeTypeMaps() { for (size_t i = 0; i < arraysize(common_media_codecs); ++i) codecs_map_.insert(common_media_codecs[i]); -#if defined(GOOGLE_CHROME_BUILD) || defined(USE_PROPRIETARY_CODECS) +#if defined(USE_PROPRIETARY_CODECS) for (size_t i = 0; i < arraysize(proprietary_media_codecs); ++i) codecs_map_.insert(proprietary_media_codecs[i]); #endif @@ -542,11 +542,9 @@ bool MatchesMimeTypeParameters(const std::string& mime_type_pattern, sort(pattern_parameters.begin(), pattern_parameters.end()); sort(test_parameters.begin(), test_parameters.end()); - std::vector<std::string> difference; - std::set_difference(pattern_parameters.begin(), pattern_parameters.end(), - test_parameters.begin(), test_parameters.end(), - std::inserter(difference, difference.begin())); - + std::vector<std::string> difference = + base::STLSetDifference<std::vector<std::string> >(pattern_parameters, + test_parameters); return difference.size() == 0; } return true; diff --git a/chromium/net/base/net_error_list.h b/chromium/net/base/net_error_list.h index 5ec421ed7b5..4379d540018 100644 --- a/chromium/net/base/net_error_list.h +++ b/chromium/net/base/net_error_list.h @@ -300,6 +300,9 @@ NET_ERROR(WS_THROTTLE_QUEUE_TOO_LARGE, -154) // was rejected. NET_ERROR(TOO_MANY_SOCKET_STREAMS, -155) +// The SSL server certificate changed in a renegotiation. +NET_ERROR(SSL_SERVER_CERT_CHANGED, -156) + // Certificate error codes // // The values of certificate error codes must be consecutive. diff --git a/chromium/net/base/net_log_event_type_list.h b/chromium/net/base/net_log_event_type_list.h index 01d9bf1e66d..837de5f3ef4 100644 --- a/chromium/net/base/net_log_event_type_list.h +++ b/chromium/net/base/net_log_event_type_list.h @@ -50,6 +50,11 @@ EVENT_TYPE(HOST_RESOLVER_IMPL) // // { // "host": <Hostname associated with the request>, +// "address_family": <The address family to restrict results to> +// "allow_cached_response": <Whether it is ok to return a result from +// the host cache> +// "is_speculative": <Whether this request was started by the DNS +// prefetcher> // "source_dependency": <Source id, if any, of what created the request>, // } // @@ -1296,6 +1301,15 @@ EVENT_TYPE(QUIC_SESSION_PACKET_RECEIVED) // as a base-10 string.>, // "size": <The size of the packet in bytes> // } +EVENT_TYPE(QUIC_SESSION_PACKET_RETRANSMITTED) + +// Session retransmitted a QUIC packet. +// { +// "old_packet_sequence_number": <The old packet's full 64-bit sequence +// number, as a base-10 string.>, +// "new_packet_sequence_number": <The new packet's full 64-bit sequence +// number, as a base-10 string.>, +// } EVENT_TYPE(QUIC_SESSION_PACKET_SENT) // Session received a QUIC packet header for a valid packet. @@ -1435,6 +1449,55 @@ EVENT_TYPE(QUIC_SESSION_CONNECTION_CLOSE_FRAME_RECEIVED) // } EVENT_TYPE(QUIC_SESSION_CONNECTION_CLOSE_FRAME_SENT) +// Session received a public reset packet. +// { +// } +EVENT_TYPE(QUIC_SESSION_PUBLIC_RESET_PACKET_RECEIVED) + +// Session received a version negotiation packet. +// { +// "versions": <List of QUIC versions supported by the server>, +// } +EVENT_TYPE(QUIC_SESSION_VERSION_NEGOTIATION_PACKET_RECEIVED) + +// Session sucessfully negotiated QUIC version number. +// { +// "version": <String of QUIC version negotiated with the server>, +// } +EVENT_TYPE(QUIC_SESSION_VERSION_NEGOTIATED) + +// Session revived a QUIC packet packet via FEC. +// { +// "guid": <The 64-bit GUID for this connection, as a base-10 string>, +// "public_flags": <The public flags set for this packet>, +// "packet_sequence_number": <The packet's full 64-bit sequence number, +// as a base-10 string.>, +// "private_flags": <The private flags set for this packet>, +// "fec_group": <The FEC group of this packet>, +// } +EVENT_TYPE(QUIC_SESSION_PACKET_HEADER_REVIVED) + +// Session received a crypto handshake message. +// { +// "quic_crypto_handshake_message": <The human readable dump of the message +// contents> +// } +EVENT_TYPE(QUIC_SESSION_CRYPTO_HANDSHAKE_MESSAGE_RECEIVED) + +// Session sent a crypto handshake message. +// { +// "quic_crypto_handshake_message": <The human readable dump of the message +// contents> +// } +EVENT_TYPE(QUIC_SESSION_CRYPTO_HANDSHAKE_MESSAGE_SENT) + +// Session was closed, either remotely or by the peer. +// { +// "quic_error": <QuicErrorCode which caused the connection to be closed>, +// "from_peer": <True if the peer closed the connection> +// } +EVENT_TYPE(QUIC_SESSION_CLOSED) + // ------------------------------------------------------------------------ // QuicHttpStream // ------------------------------------------------------------------------ @@ -1575,13 +1638,8 @@ EVENT_TYPE(NETWORK_CHANGED) // { // "nameservers": <List of name server IPs>, // "search": <List of domain suffixes>, -// "append_to_multi_label_name": <See DnsConfig>, -// "ndots": <See DnsConfig>, -// "timeout": <See DnsConfig>, -// "attempts": <See DnsConfig>, -// "rotate": <See DnsConfig>, -// "edns0": <See DnsConfig>, -// "num_hosts": <Number of entries in the HOSTS file> +// "num_hosts": <Number of entries in the HOSTS file>, +// <other>: <See DnsConfig> // } EVENT_TYPE(DNS_CONFIG_CHANGED) diff --git a/chromium/net/base/net_log_logger_unittest.cc b/chromium/net/base/net_log_logger_unittest.cc index c4ee98a28c9..3dd6915915d 100644 --- a/chromium/net/base/net_log_logger_unittest.cc +++ b/chromium/net/base/net_log_logger_unittest.cc @@ -35,7 +35,7 @@ TEST_F(NetLogLoggerTest, GeneratesValidJSONForNoEvents) { } std::string input; - ASSERT_TRUE(file_util::ReadFileToString(log_path_, &input)); + ASSERT_TRUE(base::ReadFileToString(log_path_, &input)); base::JSONReader reader; scoped_ptr<base::Value> root(reader.ReadToValue(input)); @@ -67,7 +67,7 @@ TEST_F(NetLogLoggerTest, GeneratesValidJSONWithOneEvent) { } std::string input; - ASSERT_TRUE(file_util::ReadFileToString(log_path_, &input)); + ASSERT_TRUE(base::ReadFileToString(log_path_, &input)); base::JSONReader reader; scoped_ptr<base::Value> root(reader.ReadToValue(input)); @@ -102,7 +102,7 @@ TEST_F(NetLogLoggerTest, GeneratesValidJSONWithMultipleEvents) { } std::string input; - ASSERT_TRUE(file_util::ReadFileToString(log_path_, &input)); + ASSERT_TRUE(base::ReadFileToString(log_path_, &input)); base::JSONReader reader; scoped_ptr<base::Value> root(reader.ReadToValue(input)); diff --git a/chromium/net/base/net_util.cc b/chromium/net/base/net_util.cc index 958e3c3bca8..dd0826c6a54 100644 --- a/chromium/net/base/net_util.cc +++ b/chromium/net/base/net_util.cc @@ -83,6 +83,8 @@ namespace net { namespace { +typedef std::vector<size_t> Offsets; + // what we prepend to get a file URL static const base::FilePath::CharType kFileURLPrefix[] = FILE_PATH_LITERAL("file:///"); @@ -445,8 +447,7 @@ bool IDNToUnicodeOneComponent(const base::char16* comp, } // Clamps the offsets in |offsets_for_adjustment| to the length of |str|. -void LimitOffsets(const base::string16& str, - std::vector<size_t>* offsets_for_adjustment) { +void LimitOffsets(const base::string16& str, Offsets* offsets_for_adjustment) { if (offsets_for_adjustment) { std::for_each(offsets_for_adjustment->begin(), offsets_for_adjustment->end(), @@ -461,10 +462,9 @@ void LimitOffsets(const base::string16& str, // // We may want to skip this step in the case of file URLs to allow unicode // UNC hostnames regardless of encodings. -base::string16 IDNToUnicodeWithOffsets( - const std::string& host, - const std::string& languages, - std::vector<size_t>* offsets_for_adjustment) { +base::string16 IDNToUnicodeWithOffsets(const std::string& host, + const std::string& languages, + Offsets* offsets_for_adjustment) { // Convert the ASCII input to a base::string16 for ICU. base::string16 input16; input16.reserve(host.length()); @@ -508,52 +508,48 @@ base::string16 IDNToUnicodeWithOffsets( return out16; } -// Transforms |original_offsets| by subtracting |component_begin| from all -// offsets. Any offset which was not at least this large to begin with is set -// to std::string::npos. -std::vector<size_t> OffsetsIntoComponent( - const std::vector<size_t>& original_offsets, - size_t component_begin) { - DCHECK_NE(std::string::npos, component_begin); - std::vector<size_t> offsets_into_component(original_offsets); - for (std::vector<size_t>::iterator i(offsets_into_component.begin()); - i != offsets_into_component.end(); ++i) { - if (*i != std::string::npos) - *i = (*i < component_begin) ? std::string::npos : (*i - component_begin); - } - return offsets_into_component; -} - -// Called after we transform a component and append it to an output string. -// Maps |transformed_offsets|, which represent offsets into the transformed -// component itself, into appropriate offsets for the output string, by adding -// |output_component_begin| to each. Determines which offsets need mapping by -// checking to see which of the |original_offsets| were within the designated -// original component, using its provided endpoints. -void AdjustForComponentTransform( - const std::vector<size_t>& original_offsets, - size_t original_component_begin, - size_t original_component_end, - const std::vector<size_t>& transformed_offsets, - size_t output_component_begin, - std::vector<size_t>* offsets_for_adjustment) { +// Called after transforming a component to set all affected elements in +// |offsets_for_adjustment| to the correct new values. |original_offsets| +// represents the offsets before the transform; |original_component_begin| and +// |original_component_end| represent the pre-transform boundaries of the +// affected component. |transformed_offsets| should be a vector created by +// adjusting |original_offsets| to be relative to the beginning of the component +// in question (via an OffsetAdjuster) and then transformed along with the +// component. Note that any elements in this vector which didn't originally +// point into the component may contain arbitrary values and should be ignored. +// |transformed_component_begin| and |transformed_component_end| are the +// endpoints of the transformed component and are used in combination with the +// two offset vectors to calculate the resulting absolute offsets, which are +// stored in |offsets_for_adjustment|. +void AdjustForComponentTransform(const Offsets& original_offsets, + size_t original_component_begin, + size_t original_component_end, + const Offsets& transformed_offsets, + size_t transformed_component_begin, + size_t transformed_component_end, + Offsets* offsets_for_adjustment) { if (!offsets_for_adjustment) - return; + return; // Nothing to do. - DCHECK_NE(std::string::npos, original_component_begin); - DCHECK_NE(std::string::npos, original_component_end); - DCHECK_NE(base::string16::npos, output_component_begin); - size_t offsets_size = offsets_for_adjustment->size(); - DCHECK_EQ(offsets_size, original_offsets.size()); - DCHECK_EQ(offsets_size, transformed_offsets.size()); - for (size_t i = 0; i < offsets_size; ++i) { + for (size_t i = 0; i < original_offsets.size(); ++i) { size_t original_offset = original_offsets[i]; if ((original_offset >= original_component_begin) && (original_offset < original_component_end)) { + // This offset originally pointed into the transformed component. + // Adjust the transformed relative offset by the new beginning point of + // the transformed component. size_t transformed_offset = transformed_offsets[i]; (*offsets_for_adjustment)[i] = (transformed_offset == base::string16::npos) ? - base::string16::npos : (output_component_begin + transformed_offset); + base::string16::npos : + (transformed_offset + transformed_component_begin); + } else if ((original_offset >= original_component_end) && + (original_offset != std::string::npos)) { + // This offset pointed after the transformed component. Adjust the + // original absolute offset by the difference between the new and old + // component lengths. + (*offsets_for_adjustment)[i] = + original_offset - original_component_end + transformed_component_end; } } } @@ -568,7 +564,7 @@ void AdjustComponent(int delta, url_parse::Component* component) { } // Adjusts all the components of |parsed| by |delta|, except for the scheme. -void AdjustComponents(int delta, url_parse::Parsed* parsed) { +void AdjustAllComponentsButScheme(int delta, url_parse::Parsed* parsed) { AdjustComponent(delta, &(parsed->username)); AdjustComponent(delta, &(parsed->password)); AdjustComponent(delta, &(parsed->host)); @@ -579,27 +575,36 @@ void AdjustComponents(int delta, url_parse::Parsed* parsed) { } // Helper for FormatUrlWithOffsets(). -base::string16 FormatViewSourceUrl( - const GURL& url, - const std::vector<size_t>& original_offsets, - const std::string& languages, - FormatUrlTypes format_types, - UnescapeRule::Type unescape_rules, - url_parse::Parsed* new_parsed, - size_t* prefix_end, - std::vector<size_t>* offsets_for_adjustment) { +base::string16 FormatViewSourceUrl(const GURL& url, + const Offsets& original_offsets, + const std::string& languages, + FormatUrlTypes format_types, + UnescapeRule::Type unescape_rules, + url_parse::Parsed* new_parsed, + size_t* prefix_end, + Offsets* offsets_for_adjustment) { DCHECK(new_parsed); const char kViewSource[] = "view-source:"; const size_t kViewSourceLength = arraysize(kViewSource) - 1; - std::vector<size_t> offsets_into_url( - OffsetsIntoComponent(original_offsets, kViewSourceLength)); - GURL real_url(url.possibly_invalid_spec().substr(kViewSourceLength)); + // Format the underlying URL and adjust offsets. + const std::string& url_str(url.possibly_invalid_spec()); + Offsets offsets_into_underlying_url(original_offsets); + { + base::OffsetAdjuster adjuster(&offsets_into_underlying_url); + adjuster.Add(base::OffsetAdjuster::Adjustment(0, kViewSourceLength, 0)); + } base::string16 result(ASCIIToUTF16(kViewSource) + - FormatUrlWithOffsets(real_url, languages, format_types, unescape_rules, - new_parsed, prefix_end, &offsets_into_url)); + FormatUrlWithOffsets(GURL(url_str.substr(kViewSourceLength)), languages, + format_types, unescape_rules, new_parsed, prefix_end, + &offsets_into_underlying_url)); + AdjustForComponentTransform(original_offsets, kViewSourceLength, + url_str.length(), offsets_into_underlying_url, + kViewSourceLength, result.length(), + offsets_for_adjustment); + LimitOffsets(result, offsets_for_adjustment); - // Adjust position values. + // Adjust positions of the parsed components. if (new_parsed->scheme.is_nonempty()) { // Assume "view-source:real-scheme" as a scheme. new_parsed->scheme.len += kViewSourceLength; @@ -607,13 +612,11 @@ base::string16 FormatViewSourceUrl( new_parsed->scheme.begin = 0; new_parsed->scheme.len = kViewSourceLength - 1; } - AdjustComponents(kViewSourceLength, new_parsed); + AdjustAllComponentsButScheme(kViewSourceLength, new_parsed); + if (prefix_end) *prefix_end += kViewSourceLength; - AdjustForComponentTransform(original_offsets, kViewSourceLength, - url.possibly_invalid_spec().length(), offsets_into_url, kViewSourceLength, - offsets_for_adjustment); - LimitOffsets(result, offsets_for_adjustment); + return result; } @@ -622,9 +625,8 @@ class AppendComponentTransform { AppendComponentTransform() {} virtual ~AppendComponentTransform() {} - virtual base::string16 Execute( - const std::string& component_text, - std::vector<size_t>* offsets_into_component) const = 0; + virtual base::string16 Execute(const std::string& component_text, + Offsets* offsets_into_component) const = 0; // NOTE: No DISALLOW_COPY_AND_ASSIGN here, since gcc < 4.3.0 requires an // accessible copy constructor in order to call AppendFormattedComponent() @@ -640,7 +642,7 @@ class HostComponentTransform : public AppendComponentTransform { private: virtual base::string16 Execute( const std::string& component_text, - std::vector<size_t>* offsets_into_component) const OVERRIDE { + Offsets* offsets_into_component) const OVERRIDE { return IDNToUnicodeWithOffsets(component_text, languages_, offsets_into_component); } @@ -657,7 +659,7 @@ class NonHostComponentTransform : public AppendComponentTransform { private: virtual base::string16 Execute( const std::string& component_text, - std::vector<size_t>* offsets_into_component) const OVERRIDE { + Offsets* offsets_into_component) const OVERRIDE { return (unescape_rules_ == UnescapeRule::NONE) ? base::UTF8ToUTF16AndAdjustOffsets(component_text, offsets_into_component) : @@ -668,34 +670,46 @@ class NonHostComponentTransform : public AppendComponentTransform { const UnescapeRule::Type unescape_rules_; }; +// Transforms the portion of |spec| covered by |original_component| according to +// |transform|. Appends the result to |output|. If |output_component| is +// non-NULL, its start and length are set to the transformed component's new +// start and length. For each element in |original_offsets| which is at least +// as large as original_component.begin, the corresponding element of +// |offsets_for_adjustment| is transformed appropriately. void AppendFormattedComponent(const std::string& spec, const url_parse::Component& original_component, - const std::vector<size_t>& original_offsets, + const Offsets& original_offsets, const AppendComponentTransform& transform, base::string16* output, url_parse::Component* output_component, - std::vector<size_t>* offsets_for_adjustment) { + Offsets* offsets_for_adjustment) { DCHECK(output); if (original_component.is_nonempty()) { size_t original_component_begin = static_cast<size_t>(original_component.begin); size_t output_component_begin = output->length(); - if (output_component) - output_component->begin = static_cast<int>(output_component_begin); - - std::vector<size_t> offsets_into_component = - OffsetsIntoComponent(original_offsets, original_component_begin); - output->append(transform.Execute(std::string(spec, original_component_begin, - static_cast<size_t>(original_component.len)), &offsets_into_component)); + std::string component_str(spec, original_component_begin, + static_cast<size_t>(original_component.len)); + + // Transform |component_str| and adjust the offsets accordingly. + Offsets offsets_into_component(original_offsets); + { + base::OffsetAdjuster adjuster(&offsets_into_component); + adjuster.Add(base::OffsetAdjuster::Adjustment(0, original_component_begin, + 0)); + } + output->append(transform.Execute(component_str, &offsets_into_component)); + AdjustForComponentTransform(original_offsets, original_component_begin, + static_cast<size_t>(original_component.end()), + offsets_into_component, output_component_begin, + output->length(), offsets_for_adjustment); + // Set positions of the parsed component. if (output_component) { + output_component->begin = static_cast<int>(output_component_begin); output_component->len = static_cast<int>(output->length() - output_component_begin); } - AdjustForComponentTransform(original_offsets, original_component_begin, - static_cast<size_t>(original_component.end()), - offsets_into_component, output_component_begin, - offsets_for_adjustment); } else if (output_component) { output_component->reset(); } @@ -899,6 +913,28 @@ bool FilePathToString16(const base::FilePath& path, base::string16* converted) { #endif } +bool IPNumberPrefixCheck(const IPAddressNumber& ip_number, + const unsigned char* ip_prefix, + size_t prefix_length_in_bits) { + // Compare all the bytes that fall entirely within the prefix. + int num_entire_bytes_in_prefix = prefix_length_in_bits / 8; + for (int i = 0; i < num_entire_bytes_in_prefix; ++i) { + if (ip_number[i] != ip_prefix[i]) + return false; + } + + // In case the prefix was not a multiple of 8, there will be 1 byte + // which is only partially masked. + int remaining_bits = prefix_length_in_bits % 8; + if (remaining_bits != 0) { + unsigned char mask = 0xFF << (8 - remaining_bits); + int i = num_entire_bytes_in_prefix; + if ((ip_number[i] & mask) != (ip_prefix[i] & mask)) + return false; + } + return true; +} + } // namespace const FormatUrlType kFormatUrlOmitNothing = 0; @@ -1130,12 +1166,8 @@ bool IsSafePortablePathComponent(const base::FilePath& component) { FilePathToString16(component, &component16) && file_util::IsFilenameLegal(component16) && !IsShellIntegratedExtension(extension) && - (sanitized == component.value()); -} - -bool IsSafePortableBasename(const base::FilePath& filename) { - return IsSafePortablePathComponent(filename) && - !IsReservedName(filename.value()); + (sanitized == component.value()) && + !IsReservedName(component.value()); } bool IsSafePortableRelativePath(const base::FilePath& path) { @@ -1149,7 +1181,7 @@ bool IsSafePortableRelativePath(const base::FilePath& path) { if (!IsSafePortablePathComponent(base::FilePath(components[i]))) return false; } - return IsSafePortableBasename(path.BaseName()); + return IsSafePortablePathComponent(path.BaseName()); } void GenerateSafeFileName(const std::string& mime_type, @@ -1391,7 +1423,6 @@ std::string GetHostAndOptionalPort(const GURL& url) { return url.host(); } -// static bool IsHostnameNonUnique(const std::string& hostname) { // CanonicalizeHost requires surrounding brackets to parse an IPv6 address. const std::string host_or_ip = hostname.find(':') != std::string::npos ? @@ -1404,11 +1435,24 @@ bool IsHostnameNonUnique(const std::string& hostname) { if (canonical_name.empty()) return false; - // If |hostname| is an IP address, presume it's unique. - // TODO(rsleevi): In the future, this should also reject IP addresses in - // IANA-reserved ranges. - if (host_info.IsIPAddress()) - return false; + // If |hostname| is an IP address, check to see if it's in an IANA-reserved + // range. + if (host_info.IsIPAddress()) { + IPAddressNumber host_addr; + if (!ParseIPLiteralToNumber(hostname.substr(host_info.out_host.begin, + host_info.out_host.len), + &host_addr)) { + return false; + } + switch (host_info.family) { + case url_canon::CanonHostInfo::IPV4: + case url_canon::CanonHostInfo::IPV6: + return IsIPAddressReserved(host_addr); + case url_canon::CanonHostInfo::NEUTRAL: + case url_canon::CanonHostInfo::BROKEN: + return false; + } + } // Check for a registry controlled portion of |hostname|, ignoring private // registries, as they already chain to ICANN-administered registries, @@ -1425,6 +1469,57 @@ bool IsHostnameNonUnique(const std::string& hostname) { registry_controlled_domains::EXCLUDE_PRIVATE_REGISTRIES); } +// Don't compare IPv4 and IPv6 addresses (they have different range +// reservations). Keep separate reservation arrays for each IP type, and +// consolidate adjacent reserved ranges within a reservation array when +// possible. +// Sources for info: +// www.iana.org/assignments/ipv4-address-space/ipv4-address-space.xhtml +// www.iana.org/assignments/ipv6-address-space/ipv6-address-space.xhtml +// They're formatted here with the prefix as the last element. For example: +// 10.0.0.0/8 becomes 10,0,0,0,8 and fec0::/10 becomes 0xfe,0xc0,0,0,0...,10. +bool IsIPAddressReserved(const IPAddressNumber& host_addr) { + static const unsigned char kReservedIPv4[][5] = { + { 0,0,0,0,8 }, { 10,0,0,0,8 }, { 100,64,0,0,10 }, { 127,0,0,0,8 }, + { 169,254,0,0,16 }, { 172,16,0,0,12 }, { 192,0,2,0,24 }, + { 192,88,99,0,24 }, { 192,168,0,0,16 }, { 198,18,0,0,15 }, + { 198,51,100,0,24 }, { 203,0,113,0,24 }, { 224,0,0,0,3 } + }; + static const unsigned char kReservedIPv6[][17] = { + { 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,8 }, + { 0x40,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2 }, + { 0x80,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2 }, + { 0xc0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,3 }, + { 0xe0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,4 }, + { 0xf0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,5 }, + { 0xf8,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,6 }, + { 0xfc,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,7 }, + { 0xfe,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,9 }, + { 0xfe,0x80,0,0,0,0,0,0,0,0,0,0,0,0,0,0,10 }, + { 0xfe,0xc0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,10 }, + }; + size_t array_size = 0; + const unsigned char* array = NULL; + switch (host_addr.size()) { + case kIPv4AddressSize: + array_size = arraysize(kReservedIPv4); + array = kReservedIPv4[0]; + break; + case kIPv6AddressSize: + array_size = arraysize(kReservedIPv6); + array = kReservedIPv6[0]; + break; + } + if (!array) + return false; + size_t width = host_addr.size() + 1; + for (size_t i = 0; i < array_size; ++i, array += width) { + if (IPNumberPrefixCheck(host_addr, array, array[width-1])) + return true; + } + return false; +} + // Extracts the address and port portions of a sockaddr. bool GetIPAddressFromSockAddr(const struct sockaddr* sock_addr, socklen_t sock_addr_len, @@ -1521,6 +1616,11 @@ std::string IPAddressToStringWithPort(const IPAddressNumber& addr, return IPAddressToStringWithPort(&addr.front(), addr.size(), port); } +std::string IPAddressToPackedString(const IPAddressNumber& addr) { + return std::string(reinterpret_cast<const char *>(&addr.front()), + addr.size()); +} + std::string GetHostName() { #if defined(OS_WIN) EnsureWinsockInit(); @@ -1552,7 +1652,7 @@ std::string GetHostOrSpecFromURL(const GURL& url) { void AppendFormattedHost(const GURL& url, const std::string& languages, base::string16* output) { - std::vector<size_t> offsets; + Offsets offsets; AppendFormattedComponent(url.possibly_invalid_spec(), url.parsed_for_possibly_invalid_spec().host, offsets, HostComponentTransform(languages), output, NULL, NULL); @@ -1565,13 +1665,13 @@ base::string16 FormatUrlWithOffsets( UnescapeRule::Type unescape_rules, url_parse::Parsed* new_parsed, size_t* prefix_end, - std::vector<size_t>* offsets_for_adjustment) { + Offsets* offsets_for_adjustment) { url_parse::Parsed parsed_temp; if (!new_parsed) new_parsed = &parsed_temp; else *new_parsed = url_parse::Parsed(); - std::vector<size_t> original_offsets; + Offsets original_offsets; if (offsets_for_adjustment) original_offsets = *offsets_for_adjustment; @@ -1583,7 +1683,8 @@ base::string16 FormatUrlWithOffsets( if (url.SchemeIs(kViewSource) && !StartsWithASCII(url.possibly_invalid_spec(), kViewSourceTwice, false)) { return FormatViewSourceUrl(url, original_offsets, languages, format_types, - unescape_rules, new_parsed, prefix_end, offsets_for_adjustment); + unescape_rules, new_parsed, prefix_end, + offsets_for_adjustment); } // We handle both valid and invalid URLs (this will give us the spec @@ -1641,32 +1742,13 @@ base::string16 FormatUrlWithOffsets( AppendFormattedComponent(spec, parsed.username, original_offsets, NonHostComponentTransform(unescape_rules), &url_string, &new_parsed->username, offsets_for_adjustment); - if (parsed.password.is_valid()) { - size_t colon = parsed.username.end(); - DCHECK_EQ(static_cast<size_t>(parsed.password.begin - 1), colon); - std::vector<size_t>::const_iterator colon_iter = - std::find(original_offsets.begin(), original_offsets.end(), colon); - if (colon_iter != original_offsets.end()) { - (*offsets_for_adjustment)[colon_iter - original_offsets.begin()] = - url_string.length(); - } + if (parsed.password.is_valid()) url_string.push_back(':'); - } AppendFormattedComponent(spec, parsed.password, original_offsets, NonHostComponentTransform(unescape_rules), &url_string, &new_parsed->password, offsets_for_adjustment); - if (parsed.username.is_valid() || parsed.password.is_valid()) { - size_t at_sign = (parsed.password.is_valid() ? - parsed.password : parsed.username).end(); - DCHECK_EQ(static_cast<size_t>(parsed.host.begin - 1), at_sign); - std::vector<size_t>::const_iterator at_sign_iter = - std::find(original_offsets.begin(), original_offsets.end(), at_sign); - if (at_sign_iter != original_offsets.end()) { - (*offsets_for_adjustment)[at_sign_iter - original_offsets.begin()] = - url_string.length(); - } + if (parsed.username.is_valid() || parsed.password.is_valid()) url_string.push_back('@'); - } } if (prefix_end) *prefix_end = static_cast<size_t>(url_string.length()); @@ -1694,6 +1776,10 @@ base::string16 FormatUrlWithOffsets( AppendFormattedComponent(spec, parsed.path, original_offsets, NonHostComponentTransform(unescape_rules), &url_string, &new_parsed->path, offsets_for_adjustment); + } else { + base::OffsetAdjuster offset_adjuster(offsets_for_adjustment); + offset_adjuster.Add(base::OffsetAdjuster::Adjustment( + url_string.length(), parsed.path.len, 0)); } if (parsed.query.is_valid()) url_string.push_back('?'); @@ -1702,26 +1788,11 @@ base::string16 FormatUrlWithOffsets( &new_parsed->query, offsets_for_adjustment); // Ref. This is valid, unescaped UTF-8, so we can just convert. - if (parsed.ref.is_valid()) { + if (parsed.ref.is_valid()) url_string.push_back('#'); - size_t original_ref_begin = static_cast<size_t>(parsed.ref.begin); - size_t output_ref_begin = url_string.length(); - new_parsed->ref.begin = static_cast<int>(output_ref_begin); - - std::vector<size_t> offsets_into_ref( - OffsetsIntoComponent(original_offsets, original_ref_begin)); - if (parsed.ref.len > 0) { - url_string.append(base::UTF8ToUTF16AndAdjustOffsets( - spec.substr(original_ref_begin, static_cast<size_t>(parsed.ref.len)), - &offsets_into_ref)); - } - - new_parsed->ref.len = - static_cast<int>(url_string.length() - new_parsed->ref.begin); - AdjustForComponentTransform(original_offsets, original_ref_begin, - static_cast<size_t>(parsed.ref.end()), offsets_into_ref, - output_ref_begin, offsets_for_adjustment); - } + AppendFormattedComponent(spec, parsed.ref, original_offsets, + NonHostComponentTransform(UnescapeRule::NONE), &url_string, + &new_parsed->ref, offsets_for_adjustment); // If we need to strip out http do it after the fact. This way we don't need // to worry about how offset_for_adjustment is interpreted. @@ -1739,7 +1810,7 @@ base::string16 FormatUrlWithOffsets( DCHECK(new_parsed->scheme.is_valid()); int delta = -(new_parsed->scheme.len + 3); // +3 for ://. new_parsed->scheme.reset(); - AdjustComponents(delta, new_parsed); + AdjustAllComponentsButScheme(delta, new_parsed); } LimitOffsets(url_string, offsets_for_adjustment); @@ -1753,7 +1824,7 @@ base::string16 FormatUrl(const GURL& url, url_parse::Parsed* new_parsed, size_t* prefix_end, size_t* offset_for_adjustment) { - std::vector<size_t> offsets; + Offsets offsets; if (offset_for_adjustment) offsets.push_back(*offset_for_adjustment); base::string16 result = FormatUrlWithOffsets(url, languages, format_types, @@ -1884,6 +1955,19 @@ AddressFamily GetAddressFamily(const IPAddressNumber& address) { } } +int ConvertAddressFamily(AddressFamily address_family) { + switch (address_family) { + case ADDRESS_FAMILY_UNSPECIFIED: + return AF_UNSPEC; + case ADDRESS_FAMILY_IPV4: + return AF_INET; + case ADDRESS_FAMILY_IPV6: + return AF_INET6; + } + NOTREACHED(); + return AF_UNSPEC; +} + bool ParseIPLiteralToNumber(const std::string& ip_literal, IPAddressNumber* ip_number) { // |ip_literal| could be either a IPv4 or an IPv6 literal. If it contains @@ -1996,25 +2080,7 @@ bool IPNumberMatchesPrefix(const IPAddressNumber& ip_number, 96 + prefix_length_in_bits); } - // Otherwise we are comparing two IPv4 addresses, or two IPv6 addresses. - // Compare all the bytes that fall entirely within the prefix. - int num_entire_bytes_in_prefix = prefix_length_in_bits / 8; - for (int i = 0; i < num_entire_bytes_in_prefix; ++i) { - if (ip_number[i] != ip_prefix[i]) - return false; - } - - // In case the prefix was not a multiple of 8, there will be 1 byte - // which is only partially masked. - int remaining_bits = prefix_length_in_bits % 8; - if (remaining_bits != 0) { - unsigned char mask = 0xFF << (8 - remaining_bits); - int i = num_entire_bytes_in_prefix; - if ((ip_number[i] & mask) != (ip_prefix[i] & mask)) - return false; - } - - return true; + return IPNumberPrefixCheck(ip_number, &ip_prefix[0], prefix_length_in_bits); } const uint16* GetPortFieldFromSockaddr(const struct sockaddr* address, diff --git a/chromium/net/base/net_util.h b/chromium/net/base/net_util.h index 839735e5cd5..6b8884dd114 100644 --- a/chromium/net/base/net_util.h +++ b/chromium/net/base/net_util.h @@ -112,12 +112,15 @@ NET_EXPORT std::string GetHostAndPort(const GURL& url); NET_EXPORT_PRIVATE std::string GetHostAndOptionalPort(const GURL& url); // Returns true if |hostname| contains a non-registerable or non-assignable -// domain name (eg: a gTLD that has not been assigned by IANA) -// -// TODO(rsleevi): http://crbug.com/119212 - Also match internal IP -// address ranges. +// domain name (eg: a gTLD that has not been assigned by IANA) or an IP address +// that falls in an IANA-reserved range. NET_EXPORT bool IsHostnameNonUnique(const std::string& hostname); +// Returns true if an IP address hostname is in a range reserved by the IANA. +// Works with both IPv4 and IPv6 addresses, and only compares against a given +// protocols's reserved ranges. +NET_EXPORT bool IsIPAddressReserved(const IPAddressNumber& address); + // Convenience struct for when you need a |struct sockaddr|. struct SockaddrStorage { SockaddrStorage() : addr_len(sizeof(addr_storage)), @@ -163,6 +166,9 @@ NET_EXPORT std::string IPAddressToString(const IPAddressNumber& addr); NET_EXPORT std::string IPAddressToStringWithPort( const IPAddressNumber& addr, uint16 port); +// Returns the address as a sequence of bytes in network-byte-order. +NET_EXPORT std::string IPAddressToPackedString(const IPAddressNumber& addr); + // Returns the hostname of the current system. Returns empty string on failure. NET_EXPORT std::string GetHostName(); @@ -295,20 +301,17 @@ NET_EXPORT base::FilePath GenerateFileName( const std::string& mime_type, const std::string& default_name); -// Valid basenames: +// Valid components: // * are not empty // * are not Windows reserved names (CON, NUL.zip, etc.) -// * are just basenames // * do not have trailing separators // * do not equal kCurrentDirectory // * do not reference the parent directory -// * are valid path components, which: -// - * are not the empty string -// - * do not contain illegal characters -// - * do not end with Windows shell-integrated extensions (even on posix) -// - * do not begin with '.' (which would hide them in most file managers) -// - * do not end with ' ' or '.' -NET_EXPORT bool IsSafePortableBasename(const base::FilePath& path); +// * do not contain illegal characters +// * do not end with Windows shell-integrated extensions (even on posix) +// * do not begin with '.' (which would hide them in most file managers) +// * do not end with ' ' or '.' +NET_EXPORT bool IsSafePortablePathComponent(const base::FilePath& component); // Basenames of valid relative paths are IsSafePortableBasename(), and internal // path components of valid relative paths are valid path components as @@ -367,18 +370,28 @@ NET_EXPORT void AppendFormattedHost(const GURL& url, // UTF-8, decodes %-encoding and UTF-8. // // The last three parameters may be NULL. +// // |new_parsed| will be set to the parsing parameters of the resultant URL. +// // |prefix_end| will be the length before the hostname of the resultant URL. // -// (|offset[s]_for_adjustment|) specifies one or more offsets into the original -// |url|'s spec(); each offset will be modified to reflect changes this function -// makes to the output string. For example, if |url| is "http://a:b@c.com/", -// |omit_username_password| is true, and an offset is 12 (the offset of '.'), -// then on return the output string will be "http://c.com/" and the offset will -// be 8. If an offset cannot be successfully adjusted (e.g. because it points -// into the middle of a component that was entirely removed, past the end of the -// string, or into the middle of an encoding sequence), it will be set to -// base::string16::npos. +// |offset[s]_for_adjustment| specifies one or more offsets into the original +// URL, representing insertion or selection points between characters: if the +// input is "http://foo.com/", offset 0 is before the entire URL, offset 7 is +// between the scheme and the host, and offset 15 is after the end of the URL. +// Valid input offsets range from 0 to the length of the input URL string. On +// exit, each offset will have been modified to reflect any changes made to the +// output string. For example, if |url| is "http://a:b@c.com/", +// |omit_username_password| is true, and an offset is 12 (pointing between 'c' +// and '.'), then on return the output string will be "http://c.com/" and the +// offset will be 8. If an offset cannot be successfully adjusted (e.g. because +// it points into the middle of a component that was entirely removed or into +// the middle of an encoding sequence), it will be set to base::string16::npos. +// For consistency, if an input offset points between the scheme and the +// username/password, and both are removed, on output this offset will be 0 +// rather than npos; this means that offsets at the starts and ends of removed +// components are always transformed the same way regardless of what other +// components are adjacent. NET_EXPORT base::string16 FormatUrl(const GURL& url, const std::string& languages, FormatUrlTypes format_types, @@ -435,6 +448,9 @@ bool HaveOnlyLoopbackAddresses(); NET_EXPORT_PRIVATE AddressFamily GetAddressFamily( const IPAddressNumber& address); +// Maps the given AddressFamily to either AF_INET, AF_INET6 or AF_UNSPEC. +int ConvertAddressFamily(AddressFamily address_family); + // Parses an IP address literal (either IPv4 or IPv6) to its numeric value. // Returns true on success and fills |ip_number| with the numeric value. NET_EXPORT_PRIVATE bool ParseIPLiteralToNumber(const std::string& ip_literal, diff --git a/chromium/net/base/net_util_unittest.cc b/chromium/net/base/net_util_unittest.cc index 7c6bb572d77..207e05a1cc0 100644 --- a/chromium/net/base/net_util_unittest.cc +++ b/chromium/net/base/net_util_unittest.cc @@ -452,26 +452,24 @@ void CheckAdjustedOffsets(const std::string& url_string, const std::string& languages, FormatUrlTypes format_types, UnescapeRule::Type unescape_rules, - const AdjustOffsetCase* cases, - size_t num_cases, - const size_t* all_offsets) { + const size_t* output_offsets) { GURL url(url_string); - for (size_t i = 0; i < num_cases; ++i) { - size_t offset = cases[i].input_offset; - base::string16 formatted_url = FormatUrl(url, languages, format_types, - unescape_rules, NULL, NULL, &offset); - VerboseExpect(cases[i].output_offset, offset, url_string, i, formatted_url); - } - - size_t url_size = url_string.length(); + size_t url_length = url_string.length(); std::vector<size_t> offsets; - for (size_t i = 0; i < url_size + 1; ++i) + for (size_t i = 0; i <= url_length + 1; ++i) offsets.push_back(i); + offsets.push_back(500000); // Something larger than any input length. + offsets.push_back(std::string::npos); base::string16 formatted_url = FormatUrlWithOffsets(url, languages, format_types, unescape_rules, NULL, NULL, &offsets); - for (size_t i = 0; i < url_size; ++i) - VerboseExpect(all_offsets[i], offsets[i], url_string, i, formatted_url); - VerboseExpect(kNpos, offsets[url_size], url_string, url_size, formatted_url); + for (size_t i = 0; i < url_length; ++i) + VerboseExpect(output_offsets[i], offsets[i], url_string, i, formatted_url); + VerboseExpect(formatted_url.length(), offsets[url_length], url_string, + url_length, formatted_url); + VerboseExpect(base::string16::npos, offsets[url_length + 1], url_string, + 500000, formatted_url); + VerboseExpect(base::string16::npos, offsets[url_length + 2], url_string, + std::string::npos, formatted_url); } // Helper to strignize an IP number (used to define expectations). @@ -2886,217 +2884,115 @@ TEST(NetUtilTest, FormatUrlRoundTripQueryEscaped) { } TEST(NetUtilTest, FormatUrlWithOffsets) { - const AdjustOffsetCase null_cases[] = { - {0, base::string16::npos}, - }; CheckAdjustedOffsets(std::string(), "en", kFormatUrlOmitNothing, - UnescapeRule::NORMAL, null_cases, arraysize(null_cases), NULL); - - const AdjustOffsetCase basic_cases[] = { - {0, 0}, - {3, 3}, - {5, 5}, - {6, 6}, - {13, 13}, - {21, 21}, - {22, 22}, - {23, 23}, - {25, 25}, - {26, base::string16::npos}, - {500000, base::string16::npos}, - {base::string16::npos, base::string16::npos}, + UnescapeRule::NORMAL, NULL); + + const size_t basic_offsets[] = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, + 21, 22, 23, 24, 25 }; - const size_t basic_offsets[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, - 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25}; CheckAdjustedOffsets("http://www.google.com/foo/", "en", - kFormatUrlOmitNothing, UnescapeRule::NORMAL, basic_cases, - arraysize(basic_cases), basic_offsets); - - const AdjustOffsetCase omit_auth_cases_1[] = { - {6, 6}, - {7, base::string16::npos}, - {8, base::string16::npos}, - {10, base::string16::npos}, - {12, base::string16::npos}, - {14, base::string16::npos}, - {15, 7}, - {25, 17}, + kFormatUrlOmitNothing, UnescapeRule::NORMAL, + basic_offsets); + + const size_t omit_auth_offsets_1[] = { + 0, 1, 2, 3, 4, 5, 6, 7, kNpos, kNpos, kNpos, kNpos, kNpos, kNpos, kNpos, 7, + 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21 }; - const size_t omit_auth_offsets_1[] = {0, 1, 2, 3, 4, 5, 6, kNpos, kNpos, - kNpos, kNpos, kNpos, kNpos, kNpos, kNpos, 7, 8, 9, 10, 11, 12, 13, 14, 15, - 16, 17, 18, 19, 20, 21}; CheckAdjustedOffsets("http://foo:bar@www.google.com/", "en", - kFormatUrlOmitUsernamePassword, UnescapeRule::NORMAL, omit_auth_cases_1, - arraysize(omit_auth_cases_1), omit_auth_offsets_1); + kFormatUrlOmitUsernamePassword, UnescapeRule::NORMAL, + omit_auth_offsets_1); - const AdjustOffsetCase omit_auth_cases_2[] = { - {9, base::string16::npos}, - {11, 7}, + const size_t omit_auth_offsets_2[] = { + 0, 1, 2, 3, 4, 5, 6, 7, kNpos, kNpos, kNpos, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19, 20, 21 }; - const size_t omit_auth_offsets_2[] = {0, 1, 2, 3, 4, 5, 6, kNpos, kNpos, - kNpos, kNpos, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21}; CheckAdjustedOffsets("http://foo@www.google.com/", "en", - kFormatUrlOmitUsernamePassword, UnescapeRule::NORMAL, omit_auth_cases_2, - arraysize(omit_auth_cases_2), omit_auth_offsets_2); - - // "http://foo\x30B0:\x30B0bar@www.google.com" - const AdjustOffsetCase dont_omit_auth_cases[] = { - {0, 0}, - /*{3, base::string16::npos}, - {7, 0}, - {11, 4}, - {12, base::string16::npos}, - {20, 5}, - {24, 9},*/ + kFormatUrlOmitUsernamePassword, UnescapeRule::NORMAL, + omit_auth_offsets_2); + + const size_t dont_omit_auth_offsets[] = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, kNpos, kNpos, kNpos, kNpos, kNpos, kNpos, + kNpos, kNpos, 11, 12, kNpos, kNpos, kNpos, kNpos, kNpos, kNpos, kNpos, + kNpos, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31 }; - const size_t dont_omit_auth_offsets[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, - kNpos, kNpos, kNpos, kNpos, kNpos, kNpos, kNpos, kNpos, 11, 12, kNpos, - kNpos, kNpos, kNpos, kNpos, kNpos, kNpos, kNpos, 13, 14, 15, 16, 17, 18, - 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}; + // Unescape to "http://foo\x30B0:\x30B0bar@www.google.com". CheckAdjustedOffsets("http://foo%E3%82%B0:%E3%82%B0bar@www.google.com/", "en", - kFormatUrlOmitNothing, UnescapeRule::NORMAL, dont_omit_auth_cases, - arraysize(dont_omit_auth_cases), dont_omit_auth_offsets); - - const AdjustOffsetCase view_source_cases[] = { - {0, 0}, - {3, 3}, - {11, 11}, - {12, 12}, - {13, 13}, - {18, 18}, - {19, base::string16::npos}, - {20, base::string16::npos}, - {23, 19}, - {26, 22}, - {base::string16::npos, base::string16::npos}, + kFormatUrlOmitNothing, UnescapeRule::NORMAL, + dont_omit_auth_offsets); + + const size_t view_source_offsets[] = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, kNpos, + kNpos, kNpos, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33 }; - const size_t view_source_offsets[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, - 12, 13, 14, 15, 16, 17, 18, kNpos, kNpos, kNpos, kNpos, 19, 20, 21, 22, - 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33}; CheckAdjustedOffsets("view-source:http://foo@www.google.com/", "en", - kFormatUrlOmitUsernamePassword, UnescapeRule::NORMAL, view_source_cases, - arraysize(view_source_cases), view_source_offsets); - - // "http://\x671d\x65e5\x3042\x3055\x3072.jp/foo/" - const AdjustOffsetCase idn_hostname_cases_1[] = { - {8, base::string16::npos}, - {16, base::string16::npos}, - {24, base::string16::npos}, - {25, 12}, - {30, 17}, + kFormatUrlOmitUsernamePassword, UnescapeRule::NORMAL, + view_source_offsets); + + const size_t idn_hostname_offsets_1[] = { + 0, 1, 2, 3, 4, 5, 6, 7, kNpos, kNpos, kNpos, kNpos, kNpos, kNpos, kNpos, + kNpos, kNpos, kNpos, kNpos, kNpos, kNpos, kNpos, kNpos, kNpos, kNpos, 12, + 13, 14, 15, 16, 17, 18, 19 }; - const size_t idn_hostname_offsets_1[] = {0, 1, 2, 3, 4, 5, 6, 7, kNpos, kNpos, - kNpos, kNpos, kNpos, kNpos, kNpos, kNpos, kNpos, kNpos, kNpos, kNpos, - kNpos, kNpos, kNpos, kNpos, kNpos, 12, 13, 14, 15, 16, 17, 18, 19}; + // Convert punycode to "http://\x671d\x65e5\x3042\x3055\x3072.jp/foo/". CheckAdjustedOffsets("http://xn--l8jvb1ey91xtjb.jp/foo/", "ja", - kFormatUrlOmitNothing, UnescapeRule::NORMAL, idn_hostname_cases_1, - arraysize(idn_hostname_cases_1), idn_hostname_offsets_1); - - // "http://test.\x89c6\x9891.\x5317\x4eac\x5927\x5b78.test/" - const AdjustOffsetCase idn_hostname_cases_2[] = { - {7, 7}, - {9, 9}, - {11, 11}, - {12, 12}, - {13, base::string16::npos}, - {23, base::string16::npos}, - {24, 14}, - {25, 15}, - {26, base::string16::npos}, - {32, base::string16::npos}, - {41, 19}, - {42, 20}, - {45, 23}, - {46, 24}, - {47, base::string16::npos}, - {base::string16::npos, base::string16::npos}, + kFormatUrlOmitNothing, UnescapeRule::NORMAL, + idn_hostname_offsets_1); + + const size_t idn_hostname_offsets_2[] = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, kNpos, kNpos, kNpos, kNpos, kNpos, + kNpos, kNpos, kNpos, kNpos, kNpos, kNpos, 14, 15, kNpos, kNpos, kNpos, + kNpos, kNpos, kNpos, kNpos, kNpos, kNpos, kNpos, kNpos, kNpos, kNpos, kNpos, + kNpos, 19, 20, 21, 22, 23, 24 }; - const size_t idn_hostname_offsets_2[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, - 12, kNpos, kNpos, kNpos, kNpos, kNpos, kNpos, kNpos, kNpos, kNpos, kNpos, - kNpos, 14, 15, kNpos, kNpos, kNpos, kNpos, kNpos, kNpos, kNpos, kNpos, - kNpos, kNpos, kNpos, kNpos, kNpos, kNpos, kNpos, 19, 20, 21, 22, 23, 24}; + // Convert punycode to + // "http://test.\x89c6\x9891.\x5317\x4eac\x5927\x5b78.test/". CheckAdjustedOffsets("http://test.xn--cy2a840a.xn--1lq90ic7f1rc.test/", "zh-CN", kFormatUrlOmitNothing, UnescapeRule::NORMAL, - idn_hostname_cases_2, arraysize(idn_hostname_cases_2), idn_hostname_offsets_2); - // "http://www.google.com/foo bar/\x30B0\x30FC\x30B0\x30EB" - const AdjustOffsetCase unescape_cases[] = { - {25, 25}, - {26, base::string16::npos}, - {27, base::string16::npos}, - {28, 26}, - {35, base::string16::npos}, - {41, 31}, - {59, 33}, - {60, base::string16::npos}, - {67, base::string16::npos}, - {68, base::string16::npos}, + const size_t unescape_offsets[] = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, + 21, 22, 23, 24, 25, kNpos, kNpos, 26, 27, 28, 29, 30, kNpos, kNpos, kNpos, + kNpos, kNpos, kNpos, kNpos, kNpos, 31, kNpos, kNpos, kNpos, kNpos, kNpos, + kNpos, kNpos, kNpos, 32, kNpos, kNpos, kNpos, kNpos, kNpos, kNpos, kNpos, + kNpos, 33, kNpos, kNpos, kNpos, kNpos, kNpos, kNpos, kNpos, kNpos }; - const size_t unescape_offsets[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, - 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, kNpos, kNpos, 26, 27, - 28, 29, 30, kNpos, kNpos, kNpos, kNpos, kNpos, kNpos, kNpos, kNpos, 31, - kNpos, kNpos, kNpos, kNpos, kNpos, kNpos, kNpos, kNpos, 32, kNpos, kNpos, - kNpos, kNpos, kNpos, kNpos, kNpos, kNpos, 33, kNpos, kNpos, kNpos, kNpos, - kNpos, kNpos, kNpos, kNpos}; + // Unescape to "http://www.google.com/foo bar/\x30B0\x30FC\x30B0\x30EB". CheckAdjustedOffsets( "http://www.google.com/foo%20bar/%E3%82%B0%E3%83%BC%E3%82%B0%E3%83%AB", - "en", kFormatUrlOmitNothing, UnescapeRule::SPACES, unescape_cases, - arraysize(unescape_cases), unescape_offsets); - - // "http://www.google.com/foo.html#\x30B0\x30B0z" - const AdjustOffsetCase ref_cases[] = { - {30, 30}, - {31, 31}, - {32, base::string16::npos}, - {34, 32}, - {35, base::string16::npos}, - {37, 33}, - {38, base::string16::npos}, + "en", kFormatUrlOmitNothing, UnescapeRule::SPACES, unescape_offsets); + + const size_t ref_offsets[] = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, + 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, kNpos, kNpos, 32, kNpos, kNpos, + 33 }; - const size_t ref_offsets[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, - 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, - kNpos, kNpos, 32, kNpos, kNpos, 33}; + // Unescape to "http://www.google.com/foo.html#\x30B0\x30B0z". CheckAdjustedOffsets( "http://www.google.com/foo.html#\xE3\x82\xB0\xE3\x82\xB0z", "en", - kFormatUrlOmitNothing, UnescapeRule::NORMAL, ref_cases, - arraysize(ref_cases), ref_offsets); - - const AdjustOffsetCase omit_http_cases[] = { - {0, base::string16::npos}, - {3, base::string16::npos}, - {7, 0}, - {8, 1}, + kFormatUrlOmitNothing, UnescapeRule::NORMAL, ref_offsets); + + const size_t omit_http_offsets[] = { + 0, kNpos, kNpos, kNpos, kNpos, kNpos, kNpos, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14 }; - const size_t omit_http_offsets[] = {kNpos, kNpos, kNpos, kNpos, kNpos, kNpos, - kNpos, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14}; - CheckAdjustedOffsets("http://www.google.com/", "en", - kFormatUrlOmitHTTP, UnescapeRule::NORMAL, omit_http_cases, - arraysize(omit_http_cases), omit_http_offsets); - - const AdjustOffsetCase omit_http_start_with_ftp_cases[] = { - {0, 0}, - {3, 3}, - {8, 8}, + CheckAdjustedOffsets("http://www.google.com/", "en", kFormatUrlOmitHTTP, + UnescapeRule::NORMAL, omit_http_offsets); + + const size_t omit_http_start_with_ftp_offsets[] = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21 }; - const size_t omit_http_start_with_ftp_offsets[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, - 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21}; CheckAdjustedOffsets("http://ftp.google.com/", "en", kFormatUrlOmitHTTP, - UnescapeRule::NORMAL, omit_http_start_with_ftp_cases, - arraysize(omit_http_start_with_ftp_cases), - omit_http_start_with_ftp_offsets); - - const AdjustOffsetCase omit_all_cases[] = { - {12, 0}, - {13, 1}, - {0, base::string16::npos}, - {3, base::string16::npos}, + UnescapeRule::NORMAL, omit_http_start_with_ftp_offsets); + + const size_t omit_all_offsets[] = { + 0, kNpos, kNpos, kNpos, kNpos, kNpos, kNpos, 0, kNpos, kNpos, kNpos, kNpos, + 0, 1, 2, 3, 4, 5, 6, 7 }; - const size_t omit_all_offsets[] = {kNpos, kNpos, kNpos, kNpos, kNpos, kNpos, - kNpos, kNpos, kNpos, kNpos, kNpos, kNpos, 0, 1, 2, 3, 4, 5, 6, kNpos}; CheckAdjustedOffsets("http://user@foo.com/", "en", kFormatUrlOmitAll, - UnescapeRule::NORMAL, omit_all_cases, - arraysize(omit_all_cases), omit_all_offsets); + UnescapeRule::NORMAL, omit_all_offsets); } TEST(NetUtilTest, SimplifyUrlForRequest) { @@ -3458,17 +3354,17 @@ static const base::FilePath::CharType* kSafePortableRelativePaths[] = { #endif }; -TEST(NetUtilTest, IsSafePortableBasename) { +TEST(NetUtilTest, IsSafePortablePathComponent) { for (size_t i = 0 ; i < arraysize(kSafePortableBasenames); ++i) { - EXPECT_TRUE(IsSafePortableBasename(base::FilePath( + EXPECT_TRUE(IsSafePortablePathComponent(base::FilePath( kSafePortableBasenames[i]))) << kSafePortableBasenames[i]; } for (size_t i = 0 ; i < arraysize(kUnsafePortableBasenames); ++i) { - EXPECT_FALSE(IsSafePortableBasename(base::FilePath( + EXPECT_FALSE(IsSafePortablePathComponent(base::FilePath( kUnsafePortableBasenames[i]))) << kUnsafePortableBasenames[i]; } for (size_t i = 0 ; i < arraysize(kSafePortableRelativePaths); ++i) { - EXPECT_FALSE(IsSafePortableBasename(base::FilePath( + EXPECT_FALSE(IsSafePortablePathComponent(base::FilePath( kSafePortableRelativePaths[i]))) << kSafePortableRelativePaths[i]; } } @@ -3517,15 +3413,27 @@ const NonUniqueNameTestData kNonUniqueNameTestData[] = { // Domains under private registries. { true, "appspot.com" }, { true, "test.appspot.com" }, - // IPv4 addresses (in various forms). + // Unreserved IPv4 addresses (in various forms). { true, "8.8.8.8" }, - { true, "1.2.3" }, - { true, "14.15" }, - { true, "676768" }, - // IPv6 addresses. - { true, "FEDC:ba98:7654:3210:FEDC:BA98:7654:3210" }, - { true, "::192.9.5.5" }, - { true, "FEED::BEEF" }, + { true, "99.64.0.0" }, + { true, "212.15.0.0" }, + { true, "212.15" }, + { true, "212.15.0" }, + { true, "3557752832" }, + // Reserved IPv4 addresses (in various forms). + { false, "192.168.0.0" }, + { false, "192.168.0.6" }, + { false, "10.0.0.5" }, + { false, "10.0" }, + { false, "10.0.0" }, + { false, "3232235526" }, + // Unreserved IPv6 addresses. + { true, "FFC0:ba98:7654:3210:FEDC:BA98:7654:3210" }, + { true, "2000:ba98:7654:2301:EFCD:BA98:7654:3210" }, + // Reserved IPv6 addresses. + { false, "::192.9.5.5" }, + { false, "FEED::BEEF" }, + { false, "FEC0:ba98:7654:3210:FEDC:BA98:7654:3210" }, // 'internal'/non-IANA assigned domains. { false, "intranet" }, { false, "intranet." }, diff --git a/chromium/net/base/network_change_notifier.cc b/chromium/net/base/network_change_notifier.cc index a6335d03dc2..fad9440c5bf 100644 --- a/chromium/net/base/network_change_notifier.cc +++ b/chromium/net/base/network_change_notifier.cc @@ -249,7 +249,7 @@ class HistogramWatcher // from the network thread. void NotifyDataReceived(const URLRequest& request, int bytes_read) { if (IsLocalhost(request.url().host()) || - !(request.url().SchemeIs("http") || request.url().SchemeIs("https"))) { + !request.url().SchemeIsHTTPOrHTTPS()) { return; } diff --git a/chromium/net/base/prioritized_dispatcher.cc b/chromium/net/base/prioritized_dispatcher.cc index 44348e6f5e3..b72f7a5018d 100644 --- a/chromium/net/base/prioritized_dispatcher.cc +++ b/chromium/net/base/prioritized_dispatcher.cc @@ -18,17 +18,7 @@ PrioritizedDispatcher::PrioritizedDispatcher(const Limits& limits) : queue_(limits.reserved_slots.size()), max_running_jobs_(limits.reserved_slots.size()), num_running_jobs_(0) { - size_t total = 0; - for (size_t i = 0; i < limits.reserved_slots.size(); ++i) { - total += limits.reserved_slots[i]; - max_running_jobs_[i] = total; - } - // Unreserved slots are available for all priorities. - DCHECK_LE(total, limits.total_jobs) << "sum(reserved_slots) <= total_jobs"; - size_t spare = limits.total_jobs - total; - for (size_t i = limits.reserved_slots.size(); i > 0; --i) { - max_running_jobs_[i - 1] += spare; - } + SetLimits(limits); } PrioritizedDispatcher::~PrioritizedDispatcher() {} @@ -45,6 +35,18 @@ PrioritizedDispatcher::Handle PrioritizedDispatcher::Add( return queue_.Insert(job, priority); } +PrioritizedDispatcher::Handle PrioritizedDispatcher::AddAtHead( + Job* job, Priority priority) { + DCHECK(job); + DCHECK_LT(priority, num_priorities()); + if (num_running_jobs_ < max_running_jobs_[priority]) { + ++num_running_jobs_; + job->Start(); + return Handle(); + } + return queue_.InsertAtFront(job, priority); +} + void PrioritizedDispatcher::Cancel(const Handle& handle) { queue_.Erase(handle); } @@ -78,12 +80,45 @@ PrioritizedDispatcher::Handle PrioritizedDispatcher::ChangePriority( void PrioritizedDispatcher::OnJobFinished() { DCHECK_GT(num_running_jobs_, 0u); --num_running_jobs_; - Handle handle = queue_.FirstMax(); - if (handle.is_null()) { - DCHECK_EQ(0u, queue_.size()); - return; + MaybeDispatchNextJob(); +} + +PrioritizedDispatcher::Limits PrioritizedDispatcher::GetLimits() const { + size_t num_priorities = max_running_jobs_.size(); + Limits limits(num_priorities, max_running_jobs_.back()); + + // Calculate the number of jobs reserved for each priority and higher. Leave + // the number of jobs reserved for the lowest priority or higher as 0. + for (size_t i = 1; i < num_priorities; ++i) { + limits.reserved_slots[i] = max_running_jobs_[i] - max_running_jobs_[i - 1]; + } + + return limits; +} + +void PrioritizedDispatcher::SetLimits(const Limits& limits) { + DCHECK_EQ(queue_.num_priorities(), limits.reserved_slots.size()); + size_t total = 0; + for (size_t i = 0; i < limits.reserved_slots.size(); ++i) { + total += limits.reserved_slots[i]; + max_running_jobs_[i] = total; + } + // Unreserved slots are available for all priorities. + DCHECK_LE(total, limits.total_jobs) << "sum(reserved_slots) <= total_jobs"; + size_t spare = limits.total_jobs - total; + for (size_t i = limits.reserved_slots.size(); i > 0; --i) { + max_running_jobs_[i - 1] += spare; + } + + // Start pending jobs, if limits permit. + while (true) { + if (!MaybeDispatchNextJob()) + break; } - MaybeDispatchJob(handle, handle.priority()); +} + +void PrioritizedDispatcher::SetLimitsToZero() { + SetLimits(Limits(queue_.num_priorities(), 0)); } bool PrioritizedDispatcher::MaybeDispatchJob(const Handle& handle, @@ -98,4 +133,13 @@ bool PrioritizedDispatcher::MaybeDispatchJob(const Handle& handle, return true; } +bool PrioritizedDispatcher::MaybeDispatchNextJob() { + Handle handle = queue_.FirstMax(); + if (handle.is_null()) { + DCHECK_EQ(0u, queue_.size()); + return false; + } + return MaybeDispatchJob(handle, handle.priority()); +} + } // namespace net diff --git a/chromium/net/base/prioritized_dispatcher.h b/chromium/net/base/prioritized_dispatcher.h index 708f9d6011d..e4f1019d76c 100644 --- a/chromium/net/base/prioritized_dispatcher.h +++ b/chromium/net/base/prioritized_dispatcher.h @@ -78,6 +78,10 @@ class NET_EXPORT_PRIVATE PrioritizedDispatcher { // it is queued in the dispatcher. Handle Add(Job* job, Priority priority); + // Just like Add, except that it adds Job at the font of queue of jobs with + // priorities of |priority|. + Handle AddAtHead(Job* job, Priority priority); + // Removes the job with |handle| from the queue. Invalidates |handle|. // Note: a Handle is valid iff the job is in the queue, i.e. has not Started. void Cancel(const Handle& handle); @@ -94,12 +98,29 @@ class NET_EXPORT_PRIVATE PrioritizedDispatcher { // Notifies the dispatcher that a running job has finished. Could start a job. void OnJobFinished(); + // Retrieves the Limits that |this| is currently using. This may not exactly + // match the Limits this was created with. In particular, the number of slots + // reserved for the lowest priority will always be 0, even if it was non-zero + // in the Limits passed to the constructor or to SetLimits. + Limits GetLimits() const; + + // Updates |max_running_jobs_| to match |limits|. Starts jobs if new limit + // allows. Does not stop jobs if the new limits are lower than the old ones. + void SetLimits(const Limits& limits); + + // Set the limits to zero for all priorities, allowing no new jobs to start. + void SetLimitsToZero(); + private: // Attempts to dispatch the job with |handle| at priority |priority| (might be // different than |handle.priority()|. Returns true if successful. If so // the |handle| becomes invalid. bool MaybeDispatchJob(const Handle& handle, Priority priority); + // Attempts to dispatch the next highest priority job in the queue. Returns + // true if successful, and all handles to that job become invalid. + bool MaybeDispatchNextJob(); + // Queue for jobs that need to wait for a spare slot. PriorityQueue<Job*> queue_; // Maximum total number of running jobs allowed after a job at a particular diff --git a/chromium/net/base/prioritized_dispatcher_unittest.cc b/chromium/net/base/prioritized_dispatcher_unittest.cc index 41a09c5cb92..cef455f3671 100644 --- a/chromium/net/base/prioritized_dispatcher_unittest.cc +++ b/chromium/net/base/prioritized_dispatcher_unittest.cc @@ -52,13 +52,17 @@ class PrioritizedDispatcherTest : public testing::Test { return handle_; } - void Add() { + void Add(bool at_head) { CHECK(handle_.is_null()); CHECK(!running_); size_t num_queued = dispatcher_->num_queued_jobs(); size_t num_running = dispatcher_->num_running_jobs(); - handle_ = dispatcher_->Add(this, priority_); + if (!at_head) { + handle_ = dispatcher_->Add(this, priority_); + } else { + handle_ = dispatcher_->AddAtHead(this, priority_); + } if (handle_.is_null()) { EXPECT_EQ(num_queued, dispatcher_->num_queued_jobs()); @@ -140,7 +144,14 @@ class PrioritizedDispatcherTest : public testing::Test { TestJob* AddJob(char data, Priority priority) { TestJob* job = new TestJob(dispatcher_.get(), data, priority, &log_); jobs_.push_back(job); - job->Add(); + job->Add(false); + return job; + } + + TestJob* AddJobAtHead(char data, Priority priority) { + TestJob* job = new TestJob(dispatcher_.get(), data, priority, &log_); + jobs_.push_back(job); + job->Add(true); return job; } @@ -156,6 +167,38 @@ class PrioritizedDispatcherTest : public testing::Test { ScopedVector<TestJob> jobs_; }; +TEST_F(PrioritizedDispatcherTest, GetLimits) { + // Set non-trivial initial limits. + PrioritizedDispatcher::Limits original_limits(NUM_PRIORITIES, 5); + original_limits.reserved_slots[HIGHEST] = 1; + original_limits.reserved_slots[LOW] = 2; + Prepare(original_limits); + + // Get current limits, make sure the original limits are returned. + PrioritizedDispatcher::Limits retrieved_limits = dispatcher_->GetLimits(); + ASSERT_EQ(original_limits.total_jobs, retrieved_limits.total_jobs); + ASSERT_EQ(NUM_PRIORITIES, retrieved_limits.reserved_slots.size()); + for (size_t priority = 0; priority < NUM_PRIORITIES; ++priority) { + EXPECT_EQ(original_limits.reserved_slots[priority], + retrieved_limits.reserved_slots[priority]); + } + + // Set new limits. + PrioritizedDispatcher::Limits new_limits(NUM_PRIORITIES, 6); + new_limits.reserved_slots[MEDIUM] = 3; + new_limits.reserved_slots[LOWEST] = 1; + Prepare(new_limits); + + // Get current limits, make sure the new limits are returned. + retrieved_limits = dispatcher_->GetLimits(); + ASSERT_EQ(new_limits.total_jobs, retrieved_limits.total_jobs); + ASSERT_EQ(NUM_PRIORITIES, retrieved_limits.reserved_slots.size()); + for (size_t priority = 0; priority < NUM_PRIORITIES; ++priority) { + EXPECT_EQ(new_limits.reserved_slots[priority], + retrieved_limits.reserved_slots[priority]); + } +} + TEST_F(PrioritizedDispatcherTest, AddAFIFO) { // Allow only one running job. PrioritizedDispatcher::Limits limits(NUM_PRIORITIES, 1); @@ -202,6 +245,33 @@ TEST_F(PrioritizedDispatcherTest, AddPriority) { Expect("a.c.d.b.e."); } +TEST_F(PrioritizedDispatcherTest, AddAtHead) { + PrioritizedDispatcher::Limits limits(NUM_PRIORITIES, 1); + Prepare(limits); + + TestJob* job_a = AddJob('a', MEDIUM); + TestJob* job_b = AddJobAtHead('b', MEDIUM); + TestJob* job_c = AddJobAtHead('c', HIGHEST); + TestJob* job_d = AddJobAtHead('d', HIGHEST); + TestJob* job_e = AddJobAtHead('e', MEDIUM); + TestJob* job_f = AddJob('f', MEDIUM); + + ASSERT_TRUE(job_a->running()); + job_a->Finish(); + ASSERT_TRUE(job_d->running()); + job_d->Finish(); + ASSERT_TRUE(job_c->running()); + job_c->Finish(); + ASSERT_TRUE(job_e->running()); + job_e->Finish(); + ASSERT_TRUE(job_b->running()); + job_b->Finish(); + ASSERT_TRUE(job_f->running()); + job_f->Finish(); + + Expect("a.d.c.e.b.f."); +} + TEST_F(PrioritizedDispatcherTest, EnforceLimits) { // Reserve 2 for HIGHEST and 1 for LOW or higher. // This leaves 2 for LOWEST or lower. @@ -245,29 +315,40 @@ TEST_F(PrioritizedDispatcherTest, EnforceLimits) { } TEST_F(PrioritizedDispatcherTest, ChangePriority) { - PrioritizedDispatcher::Limits limits(NUM_PRIORITIES, 1); + PrioritizedDispatcher::Limits limits(NUM_PRIORITIES, 2); + // Reserve one slot only for HIGHEST priority requests. + limits.reserved_slots[HIGHEST] = 1; Prepare(limits); TestJob* job_a = AddJob('a', IDLE); - TestJob* job_b = AddJob('b', MEDIUM); - TestJob* job_c = AddJob('c', HIGHEST); - TestJob* job_d = AddJob('d', HIGHEST); + TestJob* job_b = AddJob('b', LOW); + TestJob* job_c = AddJob('c', MEDIUM); + TestJob* job_d = AddJob('d', MEDIUM); + TestJob* job_e = AddJob('e', IDLE); ASSERT_FALSE(job_b->running()); ASSERT_FALSE(job_c->running()); - job_b->ChangePriority(HIGHEST); - job_c->ChangePriority(MEDIUM); + job_b->ChangePriority(MEDIUM); + job_c->ChangePriority(LOW); ASSERT_TRUE(job_a->running()); job_a->Finish(); ASSERT_TRUE(job_d->running()); job_d->Finish(); + + EXPECT_FALSE(job_e->running()); + // Increasing |job_e|'s priority to HIGHEST should result in it being + // started immediately. + job_e->ChangePriority(HIGHEST); + ASSERT_TRUE(job_e->running()); + job_e->Finish(); + ASSERT_TRUE(job_b->running()); job_b->Finish(); ASSERT_TRUE(job_c->running()); job_c->Finish(); - Expect("a.d.b.c."); + Expect("a.d.be..c."); } TEST_F(PrioritizedDispatcherTest, Cancel) { @@ -324,6 +405,127 @@ TEST_F(PrioritizedDispatcherTest, EvictFromEmpty) { EXPECT_TRUE(dispatcher_->EvictOldestLowest() == NULL); } +TEST_F(PrioritizedDispatcherTest, AddWhileZeroLimits) { + PrioritizedDispatcher::Limits limits(NUM_PRIORITIES, 2); + Prepare(limits); + + dispatcher_->SetLimitsToZero(); + TestJob* job_a = AddJob('a', LOW); + TestJob* job_b = AddJob('b', MEDIUM); + TestJob* job_c = AddJobAtHead('c', MEDIUM); + + EXPECT_EQ(0u, dispatcher_->num_running_jobs()); + EXPECT_EQ(3u, dispatcher_->num_queued_jobs()); + + dispatcher_->SetLimits(limits); + EXPECT_EQ(2u, dispatcher_->num_running_jobs()); + EXPECT_EQ(1u, dispatcher_->num_queued_jobs()); + + ASSERT_TRUE(job_b->running()); + job_b->Finish(); + + ASSERT_TRUE(job_c->running()); + job_c->Finish(); + + ASSERT_TRUE(job_a->running()); + job_a->Finish(); + + Expect("cb.a.."); +} + +TEST_F(PrioritizedDispatcherTest, ReduceLimitsWhileJobQueued) { + PrioritizedDispatcher::Limits initial_limits(NUM_PRIORITIES, 2); + Prepare(initial_limits); + + TestJob* job_a = AddJob('a', MEDIUM); + TestJob* job_b = AddJob('b', MEDIUM); + TestJob* job_c = AddJob('c', MEDIUM); + TestJob* job_d = AddJob('d', MEDIUM); + TestJob* job_e = AddJob('e', MEDIUM); + + EXPECT_EQ(2u, dispatcher_->num_running_jobs()); + EXPECT_EQ(3u, dispatcher_->num_queued_jobs()); + + // Reduce limits to just allow one job at a time. Running jobs should not + // be affected. + dispatcher_->SetLimits(PrioritizedDispatcher::Limits(NUM_PRIORITIES, 1)); + + EXPECT_EQ(2u, dispatcher_->num_running_jobs()); + EXPECT_EQ(3u, dispatcher_->num_queued_jobs()); + + // Finishing a job should not result in another job starting. + ASSERT_TRUE(job_a->running()); + job_a->Finish(); + EXPECT_EQ(1u, dispatcher_->num_running_jobs()); + EXPECT_EQ(3u, dispatcher_->num_queued_jobs()); + + ASSERT_TRUE(job_b->running()); + job_b->Finish(); + EXPECT_EQ(1u, dispatcher_->num_running_jobs()); + EXPECT_EQ(2u, dispatcher_->num_queued_jobs()); + + // Increasing the limits again should let c start. + dispatcher_->SetLimits(initial_limits); + + ASSERT_TRUE(job_c->running()); + job_c->Finish(); + ASSERT_TRUE(job_d->running()); + job_d->Finish(); + ASSERT_TRUE(job_e->running()); + job_e->Finish(); + + Expect("ab..cd.e.."); +} + +TEST_F(PrioritizedDispatcherTest, ZeroLimitsThenCancel) { + PrioritizedDispatcher::Limits limits(NUM_PRIORITIES, 1); + Prepare(limits); + + TestJob* job_a = AddJob('a', IDLE); + TestJob* job_b = AddJob('b', IDLE); + TestJob* job_c = AddJob('c', IDLE); + dispatcher_->SetLimitsToZero(); + + ASSERT_TRUE(job_a->running()); + EXPECT_FALSE(job_b->running()); + EXPECT_FALSE(job_c->running()); + job_a->Finish(); + + EXPECT_FALSE(job_b->running()); + EXPECT_FALSE(job_c->running()); + + // Cancelling b shouldn't start job c. + job_b->Cancel(); + EXPECT_FALSE(job_c->running()); + + // Restoring the limits should start c. + dispatcher_->SetLimits(limits); + ASSERT_TRUE(job_c->running()); + job_c->Finish(); + + Expect("a.c."); +} + +TEST_F(PrioritizedDispatcherTest, ZeroLimitsThenIncreasePriority) { + PrioritizedDispatcher::Limits limits(NUM_PRIORITIES, 2); + limits.reserved_slots[HIGHEST] = 1; + Prepare(limits); + + TestJob* job_a = AddJob('a', IDLE); + TestJob* job_b = AddJob('b', IDLE); + EXPECT_TRUE(job_a->running()); + EXPECT_FALSE(job_b->running()); + dispatcher_->SetLimitsToZero(); + + job_b->ChangePriority(HIGHEST); + EXPECT_FALSE(job_b->running()); + job_a->Finish(); + EXPECT_FALSE(job_b->running()); + + job_b->Cancel(); + Expect("a."); +} + #if GTEST_HAS_DEATH_TEST && !defined(NDEBUG) TEST_F(PrioritizedDispatcherTest, CancelNull) { PrioritizedDispatcher::Limits limits(NUM_PRIORITIES, 1); diff --git a/chromium/net/base/priority_queue.h b/chromium/net/base/priority_queue.h index b758ca45dea..c6845805354 100644 --- a/chromium/net/base/priority_queue.h +++ b/chromium/net/base/priority_queue.h @@ -139,6 +139,24 @@ class PriorityQueue : public base::NonThreadSafe { #endif } + // Adds |value| with |priority| to the queue. Returns a pointer to the + // created element. + Pointer InsertAtFront(const T& value, Priority priority) { + DCHECK(CalledOnValidThread()); + DCHECK_LT(priority, lists_.size()); + ++size_; + List& list = lists_[priority]; +#if !defined(NDEBUG) + unsigned id = next_id_; + valid_ids_.insert(id); + ++next_id_; + return Pointer(priority, list.insert(list.begin(), + std::make_pair(id, value))); +#else + return Pointer(priority, list.insert(list.begin(), value)); +#endif + } + // Removes the value pointed by |pointer| from the queue. All pointers to this // value including |pointer| become invalid. void Erase(const Pointer& pointer) { @@ -213,6 +231,9 @@ class PriorityQueue : public base::NonThreadSafe { size_ = 0u; } + // Returns the number of priorities the queue supports. + size_t num_priorities() const { return lists_.size(); } + // Returns number of queued values. size_t size() const { DCHECK(CalledOnValidThread()); diff --git a/chromium/net/base/priority_queue_unittest.cc b/chromium/net/base/priority_queue_unittest.cc index c8449bbd3fd..7e3e045fe19 100644 --- a/chromium/net/base/priority_queue_unittest.cc +++ b/chromium/net/base/priority_queue_unittest.cc @@ -91,7 +91,22 @@ TEST_F(PriorityQueueTest, EraseFromMiddle) { queue_.Erase(pointers_[2]); queue_.Erase(pointers_[3]); - int expected_order[] = { 8, 1, 6, 0, 5, 4, 7 }; + const int expected_order[] = { 8, 1, 6, 0, 5, 4, 7 }; + + for (size_t i = 0; i < arraysize(expected_order); ++i) { + EXPECT_EQ(expected_order[i], queue_.FirstMin().value()); + queue_.Erase(queue_.FirstMin()); + } + CheckEmpty(); +} + +TEST_F(PriorityQueueTest, InsertAtFront) { + queue_.InsertAtFront(9, 2); + queue_.InsertAtFront(10, 0); + queue_.InsertAtFront(11, 1); + queue_.InsertAtFront(12, 1); + + const int expected_order[] = { 10, 3, 8, 12, 11, 1, 6, 9, 0, 2, 5, 4, 7 }; for (size_t i = 0; i < arraysize(expected_order); ++i) { EXPECT_EQ(expected_order[i], queue_.FirstMin().value()); diff --git a/chromium/net/base/upload_data_stream_unittest.cc b/chromium/net/base/upload_data_stream_unittest.cc index 42b71b9cbea..e0cbc27c106 100644 --- a/chromium/net/base/upload_data_stream_unittest.cc +++ b/chromium/net/base/upload_data_stream_unittest.cc @@ -14,6 +14,7 @@ #include "base/files/scoped_temp_dir.h" #include "base/memory/scoped_ptr.h" #include "base/message_loop/message_loop.h" +#include "base/run_loop.h" #include "base/time/time.h" #include "net/base/io_buffer.h" #include "net/base/net_errors.h" @@ -123,9 +124,14 @@ class MockUploadElementReader : public UploadElementReader { class UploadDataStreamTest : public PlatformTest { public: - virtual void SetUp() OVERRIDE { + virtual void SetUp() { + PlatformTest::SetUp(); ASSERT_TRUE(temp_dir_.CreateUniqueTempDir()); } + virtual ~UploadDataStreamTest() { + element_readers_.clear(); + base::RunLoop().RunUntilIdle(); + } void FileChangedHelper(const base::FilePath& file_path, const base::Time& time, diff --git a/chromium/net/base/upload_file_element_reader_unittest.cc b/chromium/net/base/upload_file_element_reader_unittest.cc index 8224f773046..b0435019e9c 100644 --- a/chromium/net/base/upload_file_element_reader_unittest.cc +++ b/chromium/net/base/upload_file_element_reader_unittest.cc @@ -7,6 +7,7 @@ #include "base/file_util.h" #include "base/files/scoped_temp_dir.h" #include "base/message_loop/message_loop_proxy.h" +#include "base/run_loop.h" #include "net/base/io_buffer.h" #include "net/base/net_errors.h" #include "net/base/test_completion_callback.h" @@ -17,7 +18,8 @@ namespace net { class UploadFileElementReaderTest : public PlatformTest { protected: - virtual void SetUp() OVERRIDE { + virtual void SetUp() { + PlatformTest::SetUp(); // Some tests (*.ReadPartially) rely on bytes_.size() being even. const char kData[] = "123456789abcdefghi"; bytes_.assign(kData, kData + arraysize(kData) - 1); @@ -44,6 +46,11 @@ class UploadFileElementReaderTest : public PlatformTest { EXPECT_FALSE(reader_->IsInMemory()); } + virtual ~UploadFileElementReaderTest() { + reader_.reset(); + base::RunLoop().RunUntilIdle(); + } + std::vector<char> bytes_; scoped_ptr<UploadElementReader> reader_; base::ScopedTempDir temp_dir_; diff --git a/chromium/net/base/winsock_util.cc b/chromium/net/base/winsock_util.cc index 5e5c312d385..5522e2719a5 100644 --- a/chromium/net/base/winsock_util.cc +++ b/chromium/net/base/winsock_util.cc @@ -31,8 +31,6 @@ void CheckEventWait(WSAEVENT hEvent, DWORD wait_rv, DWORD expected) { #pragma optimize( "", on ) #pragma warning(pop) -net::PlatformSocketFactory* g_socket_factory = NULL; - } // namespace void AssertEventNotSignaled(WSAEVENT hEvent) { @@ -51,15 +49,4 @@ bool ResetEventIfSignaled(WSAEVENT hEvent) { return true; } -void PlatformSocketFactory::SetInstance(PlatformSocketFactory* factory) { - g_socket_factory = factory; -} - -SOCKET CreatePlatformSocket(int family, int type, int protocol) { - if (g_socket_factory) - return g_socket_factory->CreateSocket(family, type, protocol); - else - return ::WSASocket(family, type, protocol, NULL, 0, WSA_FLAG_OVERLAPPED); -} - } // namespace net diff --git a/chromium/net/base/winsock_util.h b/chromium/net/base/winsock_util.h index 06ac448a817..36d670b1a2d 100644 --- a/chromium/net/base/winsock_util.h +++ b/chromium/net/base/winsock_util.h @@ -24,26 +24,6 @@ void AssertEventNotSignaled(WSAEVENT hEvent); // optimization. The code still works if this function simply returns false. bool ResetEventIfSignaled(WSAEVENT hEvent); -// Interface to create Windows Socket. -// Usually such factories are used for testing purposes, which is not true in -// this case. This interface is used to substitute WSASocket to make possible -// execution of some network code in sandbox. -class NET_EXPORT PlatformSocketFactory { - public: - PlatformSocketFactory() {} - virtual ~PlatformSocketFactory() {} - - // Creates Windows socket. See WSASocket documentation of parameters. - virtual SOCKET CreateSocket(int family, int type, int protocol) = 0; - - // Replace WSASocket with given factory. The factory will be used by - // CreatePlatformSocket. - static void SetInstance(PlatformSocketFactory* factory); -}; - -// Creates Windows Socket. See WSASocket documentation of parameters. -SOCKET CreatePlatformSocket(int family, int type, int protocol); - } // namespace net #endif // NET_BASE_WINSOCK_UTIL_H_ diff --git a/chromium/net/cert/cert_type.h b/chromium/net/cert/cert_type.h index cb212274b12..84fc44ab1d5 100644 --- a/chromium/net/cert/cert_type.h +++ b/chromium/net/cert/cert_type.h @@ -16,7 +16,7 @@ namespace net { // UNKNOWN_CERT. If that cert is then trusted with SetCertTrust(cert, // SERVER_CERT, TRUSTED_SSL), it would become a SERVER_CERT. enum CertType { - UNKNOWN_CERT, + OTHER_CERT, CA_CERT, USER_CERT, SERVER_CERT, diff --git a/chromium/net/cert/cert_verify_proc.cc b/chromium/net/cert/cert_verify_proc.cc index ec1ef682b47..05f6c30b8f2 100644 --- a/chromium/net/cert/cert_verify_proc.cc +++ b/chromium/net/cert/cert_verify_proc.cc @@ -242,16 +242,19 @@ int CertVerifyProc::Verify(X509Certificate* cert, rv = MapCertStatusToNetError(verify_result->cert_status); } +#if !defined(OS_ANDROID) // Flag certificates from publicly-trusted CAs that are issued to intranet // hosts. While the CA/Browser Forum Baseline Requirements (v1.1) permit // these to be issued until 1 November 2015, they represent a real risk for // the deployment of gTLDs and are being phased out ahead of the hard // deadline. - // TODO(rsleevi): http://crbug.com/119212 - Also match internal IP address - // ranges. + // + // TODO(ppi): is_issued_by_known_root is incorrect on Android. Once this is + // fixed, re-enable this check for Android. crbug.com/116838 if (verify_result->is_issued_by_known_root && IsHostnameNonUnique(hostname)) { verify_result->cert_status |= CERT_STATUS_NON_UNIQUE_NAME; } +#endif return rv; } @@ -371,7 +374,7 @@ bool CertVerifyProc::IsPublicKeyBlacklisted( // in 2036, but we can probably remove in a couple of years (2014). {0xd9, 0xf5, 0xc6, 0xce, 0x57, 0xff, 0xaa, 0x39, 0xcc, 0x7e, 0xd1, 0x72, 0xbd, 0x53, 0xe0, 0xd3, 0x07, 0x83, 0x4b, 0xd1}, - // Win32/Sirefef.gen!C generates fake certifciates with this public key. + // Win32/Sirefef.gen!C generates fake certificates with this public key. {0xa4, 0xf5, 0x6e, 0x9e, 0x1d, 0x9a, 0x3b, 0x7b, 0x1a, 0xc3, 0x31, 0xcf, 0x64, 0xfc, 0x76, 0x2c, 0xd0, 0x51, 0xfb, 0xa4}, }; diff --git a/chromium/net/cert/cert_verify_proc_nss.cc b/chromium/net/cert/cert_verify_proc_nss.cc index f63297e83c5..0a0743c150d 100644 --- a/chromium/net/cert/cert_verify_proc_nss.cc +++ b/chromium/net/cert/cert_verify_proc_nss.cc @@ -764,8 +764,7 @@ int CertVerifyProcNSS::VerifyInternal( #endif // defined(OS_IOS) // Make sure that the hostname matches with the common name of the cert. - SECStatus status = CERT_VerifyCertName(cert_handle, hostname.c_str()); - if (status != SECSuccess) + if (!cert->VerifyNameMatch(hostname)) verify_result->cert_status |= CERT_STATUS_COMMON_NAME_INVALID; // Make sure that the cert is valid now. @@ -805,9 +804,9 @@ int CertVerifyProcNSS::VerifyInternal( CertificateListToCERTCertList(additional_trust_anchors)); } - status = PKIXVerifyCert(cert_handle, check_revocation, false, - cert_io_enabled, NULL, 0, trust_anchors.get(), - cvout); + SECStatus status = PKIXVerifyCert(cert_handle, check_revocation, false, + cert_io_enabled, NULL, 0, + trust_anchors.get(), cvout); if (status == SECSuccess && (flags & CertVerifier::VERIFY_REV_CHECKING_REQUIRED_LOCAL_ANCHORS) && diff --git a/chromium/net/cert/cert_verify_proc_unittest.cc b/chromium/net/cert/cert_verify_proc_unittest.cc index a53d10a0845..2f3afe06e6a 100644 --- a/chromium/net/cert/cert_verify_proc_unittest.cc +++ b/chromium/net/cert/cert_verify_proc_unittest.cc @@ -6,6 +6,7 @@ #include <vector> +#include "base/callback_helpers.h" #include "base/files/file_path.h" #include "base/logging.h" #include "base/sha1.h" @@ -245,7 +246,6 @@ TEST_F(CertVerifyProcTest, MAYBE_IntermediateCARequireExplicitPolicy) { EXPECT_EQ(0u, verify_result.cert_status); } - // Test for bug 58437. // This certificate will expire on 2011-12-21. The test will still // pass if error == ERR_CERT_DATE_INVALID. @@ -692,11 +692,19 @@ TEST_F(CertVerifyProcTest, VerifyReturnChainBasic) { certs[2]->os_cert_handle())); } +#if defined(OS_ANDROID) +// TODO(ppi): Disabled because is_issued_by_known_root is incorrect on Android. +// Once this is fixed, re-enable this check for android. crbug.com/116838 +#define MAYBE_IntranetHostsRejected DISABLED_IntranetHostsRejected +#else +#define MAYBE_IntranetHostsRejected IntranetHostsRejected +#endif + // Test that certificates issued for 'intranet' names (that is, containing no // known public registry controlled domain information) issued by well-known // CAs are flagged appropriately, while certificates that are issued by // internal CAs are not flagged. -TEST_F(CertVerifyProcTest, IntranetHostsRejected) { +TEST_F(CertVerifyProcTest, MAYBE_IntranetHostsRejected) { CertificateList cert_list = CreateCertificateListFromFile( GetTestCertsDirectory(), "ok_cert.pem", X509Certificate::FORMAT_AUTO); @@ -1356,4 +1364,69 @@ WRAPPED_INSTANTIATE_TEST_CASE_P( CertVerifyProcWeakDigestTest, testing::ValuesIn(kVerifyMixedTestData)); +// For the list of valid hostnames, see +// net/cert/data/ssl/certificates/subjectAltName_sanity_check.pem +static const struct CertVerifyProcNameData { + const char* hostname; + bool valid; // Whether or not |hostname| matches a subjectAltName. +} kVerifyNameData[] = { + { "127.0.0.1", false }, // Don't match the common name + { "127.0.0.2", true }, // Matches the iPAddress SAN (IPv4) + { "FE80:0:0:0:0:0:0:1", true }, // Matches the iPAddress SAN (IPv6) + { "[FE80:0:0:0:0:0:0:1]", false }, // Should not match the iPAddress SAN + { "FE80::1", true }, // Compressed form matches the iPAddress SAN (IPv6) + { "::127.0.0.2", false }, // IPv6 mapped form should NOT match iPAddress SAN + { "test.example", true }, // Matches the dNSName SAN + { "test.example.", true }, // Matches the dNSName SAN (trailing . ignored) + { "www.test.example", false }, // Should not match the dNSName SAN + { "test..example", false }, // Should not match the dNSName SAN + { "test.example..", false }, // Should not match the dNSName SAN + { ".test.example.", false }, // Should not match the dNSName SAN + { ".test.example", false }, // Should not match the dNSName SAN +}; + +// GTest 'magic' pretty-printer, so that if/when a test fails, it knows how +// to output the parameter that was passed. Without this, it will simply +// attempt to print out the first twenty bytes of the object, which depending +// on platform and alignment, may result in an invalid read. +void PrintTo(const CertVerifyProcNameData& data, std::ostream* os) { + *os << "Hostname: " << data.hostname << "; valid=" << data.valid; +} + +class CertVerifyProcNameTest + : public CertVerifyProcTest, + public testing::WithParamInterface<CertVerifyProcNameData> { + public: + CertVerifyProcNameTest() {} + virtual ~CertVerifyProcNameTest() {} +}; + +TEST_P(CertVerifyProcNameTest, VerifyCertName) { + CertVerifyProcNameData data = GetParam(); + + CertificateList cert_list = CreateCertificateListFromFile( + GetTestCertsDirectory(), "subjectAltName_sanity_check.pem", + X509Certificate::FORMAT_AUTO); + ASSERT_EQ(1U, cert_list.size()); + scoped_refptr<X509Certificate> cert(cert_list[0]); + + ScopedTestRoot scoped_root(cert.get()); + + CertVerifyResult verify_result; + int error = Verify(cert.get(), data.hostname, 0, NULL, empty_cert_list_, + &verify_result); + if (data.valid) { + EXPECT_EQ(OK, error); + EXPECT_FALSE(verify_result.cert_status & CERT_STATUS_COMMON_NAME_INVALID); + } else { + EXPECT_EQ(ERR_CERT_COMMON_NAME_INVALID, error); + EXPECT_TRUE(verify_result.cert_status & CERT_STATUS_COMMON_NAME_INVALID); + } +} + +WRAPPED_INSTANTIATE_TEST_CASE_P( + VerifyName, + CertVerifyProcNameTest, + testing::ValuesIn(kVerifyNameData)); + } // namespace net diff --git a/chromium/net/cert/cert_verify_proc_win.cc b/chromium/net/cert/cert_verify_proc_win.cc index 7e94246af96..b64797a1d6f 100644 --- a/chromium/net/cert/cert_verify_proc_win.cc +++ b/chromium/net/cert/cert_verify_proc_win.cc @@ -647,6 +647,7 @@ int CertVerifyProcWin::VerifyInternal( chain_flags &= ~CERT_CHAIN_REVOCATION_CHECK_CACHE_ONLY; verify_result->cert_status |= CERT_STATUS_REV_CHECKING_ENABLED; + CertFreeCertificateChain(chain_context); if (!CertGetCertificateChain( chain_engine, cert_list.get(), @@ -727,7 +728,10 @@ int CertVerifyProcWin::VerifyInternal( memset(&extra_policy_para, 0, sizeof(extra_policy_para)); extra_policy_para.cbSize = sizeof(extra_policy_para); extra_policy_para.dwAuthType = AUTHTYPE_SERVER; - extra_policy_para.fdwChecks = 0; + // Certificate name validation happens separately, later, using an internal + // routine that has better support for RFC 6125 name matching. + extra_policy_para.fdwChecks = + 0x00001000; // SECURITY_FLAG_IGNORE_CERT_CN_INVALID extra_policy_para.pwszServerName = const_cast<wchar_t*>(wstr_hostname.c_str()); @@ -752,57 +756,17 @@ int CertVerifyProcWin::VerifyInternal( if (policy_status.dwError) { verify_result->cert_status |= MapNetErrorToCertStatus( MapSecurityError(policy_status.dwError)); - - // CertVerifyCertificateChainPolicy reports only one error (in - // policy_status.dwError) if the certificate has multiple errors. - // CertGetCertificateChain doesn't report certificate name mismatch, so - // CertVerifyCertificateChainPolicy is the only function that can report - // certificate name mismatch. - // - // To prevent a potential certificate name mismatch from being hidden by - // some other certificate error, if we get any other certificate error, - // we call CertVerifyCertificateChainPolicy again, ignoring all other - // certificate errors. Both extra_policy_para.fdwChecks and - // policy_para.dwFlags allow us to ignore certificate errors, so we set - // them both. - if (policy_status.dwError != CERT_E_CN_NO_MATCH) { - const DWORD extra_ignore_flags = - 0x00000080 | // SECURITY_FLAG_IGNORE_REVOCATION - 0x00000100 | // SECURITY_FLAG_IGNORE_UNKNOWN_CA - 0x00002000 | // SECURITY_FLAG_IGNORE_CERT_DATE_INVALID - 0x00000200; // SECURITY_FLAG_IGNORE_WRONG_USAGE - extra_policy_para.fdwChecks = extra_ignore_flags; - const DWORD ignore_flags = - CERT_CHAIN_POLICY_IGNORE_ALL_NOT_TIME_VALID_FLAGS | - CERT_CHAIN_POLICY_IGNORE_INVALID_BASIC_CONSTRAINTS_FLAG | - CERT_CHAIN_POLICY_ALLOW_UNKNOWN_CA_FLAG | - CERT_CHAIN_POLICY_IGNORE_WRONG_USAGE_FLAG | - CERT_CHAIN_POLICY_IGNORE_INVALID_NAME_FLAG | - CERT_CHAIN_POLICY_IGNORE_INVALID_POLICY_FLAG | - CERT_CHAIN_POLICY_IGNORE_ALL_REV_UNKNOWN_FLAGS | - CERT_CHAIN_POLICY_ALLOW_TESTROOT_FLAG | - CERT_CHAIN_POLICY_TRUST_TESTROOT_FLAG | - CERT_CHAIN_POLICY_IGNORE_NOT_SUPPORTED_CRITICAL_EXT_FLAG | - CERT_CHAIN_POLICY_IGNORE_PEER_TRUST_FLAG; - policy_para.dwFlags = ignore_flags; - if (!CertVerifyCertificateChainPolicy( - CERT_CHAIN_POLICY_SSL, - chain_context, - &policy_para, - &policy_status)) { - return MapSecurityError(GetLastError()); - } - if (policy_status.dwError) { - verify_result->cert_status |= MapNetErrorToCertStatus( - MapSecurityError(policy_status.dwError)); - } - } } // TODO(wtc): Suppress CERT_STATUS_NO_REVOCATION_MECHANISM for now to be // compatible with WinHTTP, which doesn't report this error (bug 3004). verify_result->cert_status &= ~CERT_STATUS_NO_REVOCATION_MECHANISM; + // Perform hostname verification independent of + // CertVerifyCertificateChainPolicy. + if (!cert->VerifyNameMatch(hostname)) + verify_result->cert_status |= CERT_STATUS_COMMON_NAME_INVALID; + if (!rev_checking_enabled) { // If we didn't do online revocation checking then Windows will report // CERT_UNABLE_TO_CHECK_REVOCATION unless it had cached OCSP or CRL diff --git a/chromium/net/cert/nss_cert_database_unittest.cc b/chromium/net/cert/nss_cert_database_unittest.cc index 2e712bb54ad..4ad1192e7af 100644 --- a/chromium/net/cert/nss_cert_database_unittest.cc +++ b/chromium/net/cert/nss_cert_database_unittest.cc @@ -68,7 +68,7 @@ class CertDatabaseNSSTest : public testing::Test { static std::string ReadTestFile(const std::string& name) { std::string result; base::FilePath cert_path = GetTestCertsDirectory().AppendASCII(name); - EXPECT_TRUE(file_util::ReadFileToString(cert_path, &result)); + EXPECT_TRUE(base::ReadFileToString(cert_path, &result)); return result; } diff --git a/chromium/net/cert/test_root_certs.cc b/chromium/net/cert/test_root_certs.cc index 3219f1d84c1..f4958901685 100644 --- a/chromium/net/cert/test_root_certs.cc +++ b/chromium/net/cert/test_root_certs.cc @@ -22,7 +22,7 @@ base::LazyInstance<TestRootCerts>::Leaky CertificateList LoadCertificates(const base::FilePath& filename) { std::string raw_cert; - if (!file_util::ReadFileToString(filename, &raw_cert)) { + if (!base::ReadFileToString(filename, &raw_cert)) { LOG(ERROR) << "Can't load certificate " << filename.value(); return CertificateList(); } diff --git a/chromium/net/cert/x509_cert_types.h b/chromium/net/cert/x509_cert_types.h index b6adb518aa9..f74c82eab7b 100644 --- a/chromium/net/cert/x509_cert_types.h +++ b/chromium/net/cert/x509_cert_types.h @@ -42,7 +42,7 @@ struct NET_EXPORT CertPrincipal { bool ParseDistinguishedName(const void* ber_name_data, size_t length); #endif -#if defined(OS_MACOSX) +#if defined(OS_MACOSX) && !defined(OS_IOS) // Compare this CertPrincipal with |against|, returning true if they're // equal enough to be a possible match. This should NOT be used for any // security relevant decisions. @@ -136,9 +136,9 @@ enum CertDateFormat { // |format|, and writes the result into |*time|. If an invalid date is // specified, or if parsing fails, returns false, and |*time| will not be // updated. -bool ParseCertificateDate(const base::StringPiece& raw_date, - CertDateFormat format, - base::Time* time); +NET_EXPORT_PRIVATE bool ParseCertificateDate(const base::StringPiece& raw_date, + CertDateFormat format, + base::Time* time); } // namespace net #endif // NET_CERT_X509_CERT_TYPES_H_ diff --git a/chromium/net/cert/x509_cert_types_unittest.cc b/chromium/net/cert/x509_cert_types_unittest.cc index e0bcc707dc1..38fd3e95266 100644 --- a/chromium/net/cert/x509_cert_types_unittest.cc +++ b/chromium/net/cert/x509_cert_types_unittest.cc @@ -2,14 +2,19 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#include "base/basictypes.h" #include "net/cert/x509_cert_types.h" + +#include "base/basictypes.h" +#include "base/strings/string_piece.h" +#include "base/time/time.h" #include "net/test/test_certificate_data.h" #include "testing/gtest/include/gtest/gtest.h" namespace net { -#if defined(OS_MACOSX) +namespace { + +#if defined(OS_MACOSX) && !defined(OS_IOS) TEST(X509TypesTest, Matching) { CertPrincipal spamco; spamco.common_name = "SpamCo Dept. Of Certificization"; @@ -48,6 +53,7 @@ TEST(X509TypesTest, Matching) { } #endif +#if (defined(OS_MACOSX) && !defined(OS_IOS)) || defined(OS_WIN) TEST(X509TypesTest, ParseDNVerisign) { CertPrincipal verisign; EXPECT_TRUE(verisign.ParseDistinguishedName(VerisignDN, sizeof(VerisignDN))); @@ -135,5 +141,103 @@ TEST(X509TypesTest, ParseDNEntrust) { EXPECT_EQ("(c) 1999 Entrust.net Limited", entrust.organization_unit_names[1]); } +#endif + +const struct CertDateTestData { + CertDateFormat format; + const char* date_string; + bool is_valid; + base::Time::Exploded expected_result; +} kCertDateTimeData[] = { + { CERT_DATE_FORMAT_UTC_TIME, + "120101000000Z", + true, + { 2012, 1, 0, 1, 0, 0, 0 } }, + { CERT_DATE_FORMAT_GENERALIZED_TIME, + "20120101000000Z", + true, + { 2012, 1, 0, 1, 0, 0, 0 } }, + { CERT_DATE_FORMAT_UTC_TIME, + "490101000000Z", + true, + { 2049, 1, 0, 1, 0, 0, 0 } }, + { CERT_DATE_FORMAT_UTC_TIME, + "500101000000Z", + true, + { 1950, 1, 0, 1, 0, 0, 0 } }, + { CERT_DATE_FORMAT_GENERALIZED_TIME, + "19500101000000Z", + true, + { 1950, 1, 0, 1, 0, 0, 0 } }, + { CERT_DATE_FORMAT_UTC_TIME, + "AB0101000000Z", + false, + { 0 } }, + { CERT_DATE_FORMAT_GENERALIZED_TIME, + "19AB0101000000Z", + false, + { 0 } }, + { CERT_DATE_FORMAT_UTC_TIME, + "", + false, + { 0 } }, + { CERT_DATE_FORMAT_UTC_TIME, + "A", + false, + { 0 } }, + { CERT_DATE_FORMAT_GENERALIZED_TIME, + "20121301000000Z", + false, + { 0 } }, + { CERT_DATE_FORMAT_GENERALIZED_TIME, + "20120101123000Z", + true, + { 2012, 1, 0, 1, 12, 30, 0 } }, +}; + +// GTest pretty printer. +void PrintTo(const CertDateTestData& data, std::ostream* os) { + *os << " format: " << data.format + << "; date string: " << base::StringPiece(data.date_string) + << "; valid: " << data.is_valid + << "; expected date: " + << (data.is_valid ? + base::Time::FromUTCExploded(data.expected_result) + .ToInternalValue() : + 0U); +} + +class X509CertTypesDateTest : public testing::TestWithParam<CertDateTestData> { + public: + virtual ~X509CertTypesDateTest() {} + virtual void SetUp() { + test_data_ = GetParam(); + } + + protected: + CertDateTestData test_data_; +}; + +TEST_P(X509CertTypesDateTest, Parse) { + base::Time parsed_date; + bool parsed = ParseCertificateDate( + test_data_.date_string, test_data_.format, &parsed_date); + EXPECT_EQ(test_data_.is_valid, parsed); + if (!test_data_.is_valid) + return; + // Convert the expected value to a base::Time(). This ensures that systems + // systems that only support 32-bit times will pass the tests, by ensuring at + // least that the times have the same truncating behaviour. + // Note: Compared as internal values so that mismatches can be cleanly + // printed by GTest (eg: without PrintTo overrides). + EXPECT_EQ(base::Time::FromUTCExploded(test_data_.expected_result) + .ToInternalValue(), + parsed_date.ToInternalValue()); +} +INSTANTIATE_TEST_CASE_P(, + X509CertTypesDateTest, + testing::ValuesIn(kCertDateTimeData)); + +} // namespace } // namespace net diff --git a/chromium/net/cert/x509_certificate_mac.cc b/chromium/net/cert/x509_certificate_mac.cc index 2f8ce438afd..96f7b52ab3d 100644 --- a/chromium/net/cert/x509_certificate_mac.cc +++ b/chromium/net/cert/x509_certificate_mac.cc @@ -283,9 +283,9 @@ class ScopedEncodedCertResults { for (uint32 i = 0; i < results_->NumberOfResults; i++) { crypto::CSSMFree(encCert[i].CertBlob.Data); } + crypto::CSSMFree(results_->Results); + crypto::CSSMFree(results_); } - crypto::CSSMFree(results_->Results); - crypto::CSSMFree(results_); } private: diff --git a/chromium/net/cert/x509_certificate_nss.cc b/chromium/net/cert/x509_certificate_nss.cc index ca256504817..9e95413a22e 100644 --- a/chromium/net/cert/x509_certificate_nss.cc +++ b/chromium/net/cert/x509_certificate_nss.cc @@ -106,7 +106,7 @@ std::string X509Certificate::GetDefaultNickname(CertType type) const { case SERVER_CERT: result = subject_.GetDisplayName(); break; - case UNKNOWN_CERT: + case OTHER_CERT: default: break; } diff --git a/chromium/net/cert/x509_certificate_openssl.cc b/chromium/net/cert/x509_certificate_openssl.cc index bdf2bf20538..71d558ddc84 100644 --- a/chromium/net/cert/x509_certificate_openssl.cc +++ b/chromium/net/cert/x509_certificate_openssl.cc @@ -456,7 +456,7 @@ void X509Certificate::GetPublicKeyInfo(OSCertHandle cert_handle, break; case EVP_PKEY_EC: *type = kPublicKeyTypeECDSA; - *size_bits = EVP_PKEY_size(key); + *size_bits = EVP_PKEY_bits(key); break; case EVP_PKEY_DH: *type = kPublicKeyTypeDH; diff --git a/chromium/net/cert/x509_certificate_unittest.cc b/chromium/net/cert/x509_certificate_unittest.cc index 75ba827a18a..04dbaddf69f 100644 --- a/chromium/net/cert/x509_certificate_unittest.cc +++ b/chromium/net/cert/x509_certificate_unittest.cc @@ -23,6 +23,10 @@ #include <cert.h> #endif +#if defined(OS_WIN) +#include "base/win/windows_version.h" +#endif + using base::HexEncode; using base::Time; @@ -84,75 +88,6 @@ const double kGoogleParseValidFrom = 1261094400; // Dec 18 23:59:59 2011 GMT const double kGoogleParseValidTo = 1324252799; -struct CertificateFormatTestData { - const char* file_name; - X509Certificate::Format format; - uint8* chain_fingerprints[3]; -}; - -const CertificateFormatTestData FormatTestData[] = { - // DER Parsing - single certificate, DER encoded - { "google.single.der", X509Certificate::FORMAT_SINGLE_CERTIFICATE, - { google_parse_fingerprint, - NULL, } }, - // DER parsing - single certificate, PEM encoded - { "google.single.pem", X509Certificate::FORMAT_SINGLE_CERTIFICATE, - { google_parse_fingerprint, - NULL, } }, - // PEM parsing - single certificate, PEM encoded with a PEB of - // "CERTIFICATE" - { "google.single.pem", X509Certificate::FORMAT_PEM_CERT_SEQUENCE, - { google_parse_fingerprint, - NULL, } }, - // PEM parsing - sequence of certificates, PEM encoded with a PEB of - // "CERTIFICATE" - { "google.chain.pem", X509Certificate::FORMAT_PEM_CERT_SEQUENCE, - { google_parse_fingerprint, - thawte_parse_fingerprint, - NULL, } }, - // PKCS#7 parsing - "degenerate" SignedData collection of certificates, DER - // encoding - { "google.binary.p7b", X509Certificate::FORMAT_PKCS7, - { google_parse_fingerprint, - thawte_parse_fingerprint, - NULL, } }, - // PKCS#7 parsing - "degenerate" SignedData collection of certificates, PEM - // encoded with a PEM PEB of "CERTIFICATE" - { "google.pem_cert.p7b", X509Certificate::FORMAT_PKCS7, - { google_parse_fingerprint, - thawte_parse_fingerprint, - NULL, } }, - // PKCS#7 parsing - "degenerate" SignedData collection of certificates, PEM - // encoded with a PEM PEB of "PKCS7" - { "google.pem_pkcs7.p7b", X509Certificate::FORMAT_PKCS7, - { google_parse_fingerprint, - thawte_parse_fingerprint, - NULL, } }, - // All of the above, this time using auto-detection - { "google.single.der", X509Certificate::FORMAT_AUTO, - { google_parse_fingerprint, - NULL, } }, - { "google.single.pem", X509Certificate::FORMAT_AUTO, - { google_parse_fingerprint, - NULL, } }, - { "google.chain.pem", X509Certificate::FORMAT_AUTO, - { google_parse_fingerprint, - thawte_parse_fingerprint, - NULL, } }, - { "google.binary.p7b", X509Certificate::FORMAT_AUTO, - { google_parse_fingerprint, - thawte_parse_fingerprint, - NULL, } }, - { "google.pem_cert.p7b", X509Certificate::FORMAT_AUTO, - { google_parse_fingerprint, - thawte_parse_fingerprint, - NULL, } }, - { "google.pem_pkcs7.p7b", X509Certificate::FORMAT_AUTO, - { google_parse_fingerprint, - thawte_parse_fingerprint, - NULL, } }, -}; - void CheckGoogleCert(const scoped_refptr<X509Certificate>& google_cert, uint8* expected_fingerprint, double valid_from, double valid_to) { @@ -869,6 +804,73 @@ TEST(X509CertificateTest, GetDefaultNickname) { } #endif +const struct CertificateFormatTestData { + const char* file_name; + X509Certificate::Format format; + uint8* chain_fingerprints[3]; +} kFormatTestData[] = { + // DER Parsing - single certificate, DER encoded + { "google.single.der", X509Certificate::FORMAT_SINGLE_CERTIFICATE, + { google_parse_fingerprint, + NULL, } }, + // DER parsing - single certificate, PEM encoded + { "google.single.pem", X509Certificate::FORMAT_SINGLE_CERTIFICATE, + { google_parse_fingerprint, + NULL, } }, + // PEM parsing - single certificate, PEM encoded with a PEB of + // "CERTIFICATE" + { "google.single.pem", X509Certificate::FORMAT_PEM_CERT_SEQUENCE, + { google_parse_fingerprint, + NULL, } }, + // PEM parsing - sequence of certificates, PEM encoded with a PEB of + // "CERTIFICATE" + { "google.chain.pem", X509Certificate::FORMAT_PEM_CERT_SEQUENCE, + { google_parse_fingerprint, + thawte_parse_fingerprint, + NULL, } }, + // PKCS#7 parsing - "degenerate" SignedData collection of certificates, DER + // encoding + { "google.binary.p7b", X509Certificate::FORMAT_PKCS7, + { google_parse_fingerprint, + thawte_parse_fingerprint, + NULL, } }, + // PKCS#7 parsing - "degenerate" SignedData collection of certificates, PEM + // encoded with a PEM PEB of "CERTIFICATE" + { "google.pem_cert.p7b", X509Certificate::FORMAT_PKCS7, + { google_parse_fingerprint, + thawte_parse_fingerprint, + NULL, } }, + // PKCS#7 parsing - "degenerate" SignedData collection of certificates, PEM + // encoded with a PEM PEB of "PKCS7" + { "google.pem_pkcs7.p7b", X509Certificate::FORMAT_PKCS7, + { google_parse_fingerprint, + thawte_parse_fingerprint, + NULL, } }, + // All of the above, this time using auto-detection + { "google.single.der", X509Certificate::FORMAT_AUTO, + { google_parse_fingerprint, + NULL, } }, + { "google.single.pem", X509Certificate::FORMAT_AUTO, + { google_parse_fingerprint, + NULL, } }, + { "google.chain.pem", X509Certificate::FORMAT_AUTO, + { google_parse_fingerprint, + thawte_parse_fingerprint, + NULL, } }, + { "google.binary.p7b", X509Certificate::FORMAT_AUTO, + { google_parse_fingerprint, + thawte_parse_fingerprint, + NULL, } }, + { "google.pem_cert.p7b", X509Certificate::FORMAT_AUTO, + { google_parse_fingerprint, + thawte_parse_fingerprint, + NULL, } }, + { "google.pem_pkcs7.p7b", X509Certificate::FORMAT_AUTO, + { google_parse_fingerprint, + thawte_parse_fingerprint, + NULL, } }, +}; + class X509CertificateParseTest : public testing::TestWithParam<CertificateFormatTestData> { public: @@ -915,7 +917,7 @@ TEST_P(X509CertificateParseTest, CanParseFormat) { } INSTANTIATE_TEST_CASE_P(, X509CertificateParseTest, - testing::ValuesIn(FormatTestData)); + testing::ValuesIn(kFormatTestData)); struct CertificateNameVerifyTestData { // true iff we expect hostname to match an entry in cert_names. @@ -1144,4 +1146,50 @@ TEST_P(X509CertificateNameVerifyTest, VerifyHostname) { INSTANTIATE_TEST_CASE_P(, X509CertificateNameVerifyTest, testing::ValuesIn(kNameVerifyTestData)); +const struct PublicKeyInfoTestData { + const char* cert_file; + size_t expected_bits; + X509Certificate::PublicKeyType expected_type; +} kPublicKeyInfoTestData[] = { + { "768-rsa-ee-by-768-rsa-intermediate.pem", 768, + X509Certificate::kPublicKeyTypeRSA }, + { "1024-rsa-ee-by-768-rsa-intermediate.pem", 1024, + X509Certificate::kPublicKeyTypeRSA }, + { "prime256v1-ecdsa-ee-by-1024-rsa-intermediate.pem", 256, + X509Certificate::kPublicKeyTypeECDSA }, +}; + +class X509CertificatePublicKeyInfoTest + : public testing::TestWithParam<PublicKeyInfoTestData> { +}; + +TEST_P(X509CertificatePublicKeyInfoTest, GetPublicKeyInfo) { + PublicKeyInfoTestData data = GetParam(); + +#if defined(OS_WIN) + if (base::win::GetVersion() < base::win::VERSION_VISTA && + data.expected_type == X509Certificate::kPublicKeyTypeECDSA) { + // ECC is only supported on Vista+. Skip the test. + return; + } +#endif + + scoped_refptr<X509Certificate> cert( + ImportCertFromFile(GetTestCertsDirectory(), data.cert_file)); + ASSERT_TRUE(cert.get()); + + size_t actual_bits = 0; + X509Certificate::PublicKeyType actual_type = + X509Certificate::kPublicKeyTypeUnknown; + + X509Certificate::GetPublicKeyInfo(cert->os_cert_handle(), &actual_bits, + &actual_type); + + EXPECT_EQ(data.expected_bits, actual_bits); + EXPECT_EQ(data.expected_type, actual_type); +} + +INSTANTIATE_TEST_CASE_P(, X509CertificatePublicKeyInfoTest, + testing::ValuesIn(kPublicKeyInfoTestData)); + } // namespace net diff --git a/chromium/net/cookies/cookie_monster.cc b/chromium/net/cookies/cookie_monster.cc index f24637735fd..65723ff20a6 100644 --- a/chromium/net/cookies/cookie_monster.cc +++ b/chromium/net/cookies/cookie_monster.cc @@ -112,6 +112,15 @@ const int CookieMonster::kSafeFromGlobalPurgeDays = 30; namespace { +bool ContainsControlCharacter(const std::string& s) { + for (std::string::const_iterator i = s.begin(); i != s.end(); ++i) { + if ((*i >= 0) && (*i <= 31)) + return true; + } + + return false; +} + typedef std::vector<CanonicalCookie*> CanonicalCookieVector; // Default minimum delay after updating a cookie's LastAccessDate before we @@ -286,6 +295,8 @@ ChangeCausePair ChangeCauseMapping[] = { { CookieMonster::Delegate::CHANGE_COOKIE_EVICTED, true }, // DELETE_COOKIE_EXPIRED_OVERWRITE { CookieMonster::Delegate::CHANGE_COOKIE_EXPIRED_OVERWRITE, true }, + // DELETE_COOKIE_CONTROL_CHAR + { CookieMonster::Delegate::CHANGE_COOKIE_EVICTED, true}, // DELETE_COOKIE_LAST_ENTRY { CookieMonster::Delegate::CHANGE_COOKIE_EXPLICIT, false } }; @@ -1477,16 +1488,24 @@ void CookieMonster::StoreLoadedCookies( // and sync'd. base::AutoLock autolock(lock_); + CookieItVector cookies_with_control_chars; + for (std::vector<CanonicalCookie*>::const_iterator it = cookies.begin(); it != cookies.end(); ++it) { int64 cookie_creation_time = (*it)->CreationDate().ToInternalValue(); if (creation_times_.insert(cookie_creation_time).second) { - InternalInsertCookie(GetKey((*it)->Domain()), *it, false); + CookieMap::iterator inserted = + InternalInsertCookie(GetKey((*it)->Domain()), *it, false); const Time cookie_access_time((*it)->LastAccessDate()); if (earliest_access_time_.is_null() || cookie_access_time < earliest_access_time_) earliest_access_time_ = cookie_access_time; + + if (ContainsControlCharacter((*it)->Name()) || + ContainsControlCharacter((*it)->Value())) { + cookies_with_control_chars.push_back(inserted); + } } else { LOG(ERROR) << base::StringPrintf("Found cookies with duplicate creation " "times in backing store: " @@ -1500,6 +1519,16 @@ void CookieMonster::StoreLoadedCookies( } } + // Any cookies that contain control characters that we have loaded from the + // persistent store should be deleted. See http://crbug.com/238041. + for (CookieItVector::iterator it = cookies_with_control_chars.begin(); + it != cookies_with_control_chars.end();) { + CookieItVector::iterator curit = it; + ++it; + + InternalDeleteCookie(*curit, true, DELETE_COOKIE_CONTROL_CHAR); + } + // After importing cookies from the PersistentCookieStore, verify that // none of our other constraints are violated. // In particular, the backing store might have given us duplicate cookies. @@ -1733,19 +1762,23 @@ bool CookieMonster::DeleteAnyEquivalentCookie(const std::string& key, return skipped_httponly; } -void CookieMonster::InternalInsertCookie(const std::string& key, - CanonicalCookie* cc, - bool sync_to_store) { +CookieMonster::CookieMap::iterator CookieMonster::InternalInsertCookie( + const std::string& key, + CanonicalCookie* cc, + bool sync_to_store) { lock_.AssertAcquired(); if ((cc->IsPersistent() || persist_session_cookies_) && store_.get() && sync_to_store) store_->AddCookie(*cc); - cookies_.insert(CookieMap::value_type(key, cc)); + CookieMap::iterator inserted = + cookies_.insert(CookieMap::value_type(key, cc)); if (delegate_.get()) { delegate_->OnCookieChanged( *cc, false, CookieMonster::Delegate::CHANGE_COOKIE_EXPLICIT); } + + return inserted; } bool CookieMonster::SetCookieWithCreationTimeAndOptions( @@ -1831,6 +1864,8 @@ void CookieMonster::InternalUpdateCookieAccessTime(CanonicalCookie* cc, store_->UpdateCookieAccessTime(*cc); } +// InternalDeleteCookies must not invalidate iterators other than the one being +// deleted. void CookieMonster::InternalDeleteCookie(CookieMap::iterator it, bool sync_to_store, DeletionCause deletion_cause) { diff --git a/chromium/net/cookies/cookie_monster.h b/chromium/net/cookies/cookie_monster.h index eaf89d33810..1df616cd6d9 100644 --- a/chromium/net/cookies/cookie_monster.h +++ b/chromium/net/cookies/cookie_monster.h @@ -379,6 +379,11 @@ class NET_EXPORT CookieMonster : public CookieStore { // already-expired expiration date. This captures that case. DELETE_COOKIE_EXPIRED_OVERWRITE, + // Cookies are not allowed to contain control characters in the name or + // value. However, we used to allow them, so we are now evicting any such + // cookies as we load them. See http://crbug.com/238041. + DELETE_COOKIE_CONTROL_CHAR, + DELETE_COOKIE_LAST_ENTRY }; @@ -517,10 +522,11 @@ class NET_EXPORT CookieMonster : public CookieStore { bool skip_httponly, bool already_expired); - // Takes ownership of *cc. - void InternalInsertCookie(const std::string& key, - CanonicalCookie* cc, - bool sync_to_store); + // Takes ownership of *cc. Returns an iterator that points to the inserted + // cookie in cookies_. Guarantee: all iterators to cookies_ remain valid. + CookieMap::iterator InternalInsertCookie(const std::string& key, + CanonicalCookie* cc, + bool sync_to_store); // Helper function that sets cookies with more control. // Not exposed as we don't want callers to have the ability @@ -541,6 +547,8 @@ class NET_EXPORT CookieMonster : public CookieStore { // |deletion_cause| argument is used for collecting statistics and choosing // the correct Delegate::ChangeCause for OnCookieChanged notifications. + // Guarantee: All iterators to cookies_ except to the deleted entry remain + // vaild. void InternalDeleteCookie(CookieMap::iterator it, bool sync_to_store, DeletionCause deletion_cause); diff --git a/chromium/net/cookies/cookie_monster_perftest.cc b/chromium/net/cookies/cookie_monster_perftest.cc index c70516f70eb..2bc0be8480c 100644 --- a/chromium/net/cookies/cookie_monster_perftest.cc +++ b/chromium/net/cookies/cookie_monster_perftest.cc @@ -6,9 +6,9 @@ #include "base/bind.h" #include "base/message_loop/message_loop.h" -#include "base/perftimer.h" #include "base/strings/string_util.h" #include "base/strings/stringprintf.h" +#include "base/test/perf_time_logger.h" #include "net/cookies/canonical_cookie.h" #include "net/cookies/cookie_monster.h" #include "net/cookies/cookie_monster_store_test.h" @@ -97,7 +97,7 @@ class GetCookiesCallback : public BaseCallback { TEST(ParsedCookieTest, TestParseCookies) { std::string cookie(kCookieLine); - PerfTimeLogger timer("Parsed_cookie_parse_cookies"); + base::PerfTimeLogger timer("Parsed_cookie_parse_cookies"); for (int i = 0; i < kNumCookies; ++i) { ParsedCookie pc(cookie); EXPECT_TRUE(pc.IsValid()); @@ -108,7 +108,7 @@ TEST(ParsedCookieTest, TestParseCookies) { TEST(ParsedCookieTest, TestParseBigCookies) { std::string cookie(3800, 'z'); cookie += kCookieLine; - PerfTimeLogger timer("Parsed_cookie_parse_big_cookies"); + base::PerfTimeLogger timer("Parsed_cookie_parse_big_cookies"); for (int i = 0; i < kNumCookies; ++i) { ParsedCookie pc(cookie); EXPECT_TRUE(pc.IsValid()); @@ -126,7 +126,7 @@ TEST_F(CookieMonsterTest, TestAddCookiesOnSingleHost) { SetCookieCallback setCookieCallback; // Add a bunch of cookies on a single host - PerfTimeLogger timer("Cookie_monster_add_single_host"); + base::PerfTimeLogger timer("Cookie_monster_add_single_host"); for (std::vector<std::string>::const_iterator it = cookies.begin(); it != cookies.end(); ++it) { @@ -136,14 +136,14 @@ TEST_F(CookieMonsterTest, TestAddCookiesOnSingleHost) { GetCookiesCallback getCookiesCallback; - PerfTimeLogger timer2("Cookie_monster_query_single_host"); + base::PerfTimeLogger timer2("Cookie_monster_query_single_host"); for (std::vector<std::string>::const_iterator it = cookies.begin(); it != cookies.end(); ++it) { getCookiesCallback.GetCookies(cm.get(), GURL(kGoogleURL)); } timer2.Done(); - PerfTimeLogger timer3("Cookie_monster_deleteall_single_host"); + base::PerfTimeLogger timer3("Cookie_monster_deleteall_single_host"); cm->DeleteAllAsync(CookieMonster::DeleteCallback()); base::MessageLoop::current()->RunUntilIdle(); timer3.Done(); @@ -160,7 +160,7 @@ TEST_F(CookieMonsterTest, TestAddCookieOnManyHosts) { SetCookieCallback setCookieCallback; // Add a cookie on a bunch of host - PerfTimeLogger timer("Cookie_monster_add_many_hosts"); + base::PerfTimeLogger timer("Cookie_monster_add_many_hosts"); for (std::vector<GURL>::const_iterator it = gurls.begin(); it != gurls.end(); ++it) { setCookieCallback.SetCookie(cm.get(), *it, cookie); @@ -169,14 +169,14 @@ TEST_F(CookieMonsterTest, TestAddCookieOnManyHosts) { GetCookiesCallback getCookiesCallback; - PerfTimeLogger timer2("Cookie_monster_query_many_hosts"); + base::PerfTimeLogger timer2("Cookie_monster_query_many_hosts"); for (std::vector<GURL>::const_iterator it = gurls.begin(); it != gurls.end(); ++it) { getCookiesCallback.GetCookies(cm.get(), *it); } timer2.Done(); - PerfTimeLogger timer3("Cookie_monster_deleteall_many_hosts"); + base::PerfTimeLogger timer3("Cookie_monster_deleteall_many_hosts"); cm->DeleteAllAsync(CookieMonster::DeleteCallback()); base::MessageLoop::current()->RunUntilIdle(); timer3.Done(); @@ -229,7 +229,7 @@ TEST_F(CookieMonsterTest, TestDomainTree) { std::string cookie_line = getCookiesCallback.GetCookies(cm.get(), probe_gurl); EXPECT_EQ(5, CountInString(cookie_line, '=')) << "Cookie line: " << cookie_line; - PerfTimeLogger timer("Cookie_monster_query_domain_tree"); + base::PerfTimeLogger timer("Cookie_monster_query_domain_tree"); for (int i = 0; i < kNumCookies; i++) { getCookiesCallback.GetCookies(cm.get(), probe_gurl); } @@ -269,7 +269,7 @@ TEST_F(CookieMonsterTest, TestDomainLine) { cookie_line = getCookiesCallback.GetCookies(cm.get(), probe_gurl); EXPECT_EQ(32, CountInString(cookie_line, '=')); - PerfTimeLogger timer2("Cookie_monster_query_domain_line"); + base::PerfTimeLogger timer2("Cookie_monster_query_domain_line"); for (int i = 0; i < kNumCookies; i++) { getCookiesCallback.GetCookies(cm.get(), probe_gurl); } @@ -304,7 +304,7 @@ TEST_F(CookieMonsterTest, TestImport) { // Import will happen on first access. GURL gurl("www.google.com"); CookieOptions options; - PerfTimeLogger timer("Cookie_monster_import_from_store"); + base::PerfTimeLogger timer("Cookie_monster_import_from_store"); getCookiesCallback.GetCookies(cm.get(), gurl); timer.Done(); @@ -314,7 +314,7 @@ TEST_F(CookieMonsterTest, TestImport) { TEST_F(CookieMonsterTest, TestGetKey) { scoped_refptr<CookieMonster> cm(new CookieMonster(NULL, NULL)); - PerfTimeLogger timer("Cookie_monster_get_key"); + base::PerfTimeLogger timer("Cookie_monster_get_key"); for (int i = 0; i < kNumCookies; i++) cm->GetKey("www.google.com"); timer.Done(); @@ -375,7 +375,7 @@ TEST_F(CookieMonsterTest, TestGCTimes) { // Trigger the Garbage collection we're allowed. setCookieCallback.SetCookie(cm.get(), gurl, cookie_line); - PerfTimeLogger timer((std::string("GC_") + test_case.name).c_str()); + base::PerfTimeLogger timer((std::string("GC_") + test_case.name).c_str()); for (int i = 0; i < kNumCookies; i++) setCookieCallback.SetCookie(cm.get(), gurl, cookie_line); timer.Done(); diff --git a/chromium/net/cookies/cookie_monster_unittest.cc b/chromium/net/cookies/cookie_monster_unittest.cc index d1ce04f3885..2bfe9ee9978 100644 --- a/chromium/net/cookies/cookie_monster_unittest.cc +++ b/chromium/net/cookies/cookie_monster_unittest.cc @@ -2685,4 +2685,45 @@ TEST_F(CookieMonsterTest, PersisentCookieStorageTest) { EXPECT_EQ(5u, store->commands().size()); } +// Test to assure that cookies with control characters are purged appropriately. +// See http://crbug.com/238041 for background. +TEST_F(CookieMonsterTest, ControlCharacterPurge) { + const Time now1(Time::Now()); + const Time now2(Time::Now() + TimeDelta::FromSeconds(1)); + const Time now3(Time::Now() + TimeDelta::FromSeconds(2)); + const Time later(now1 + TimeDelta::FromDays(1)); + const GURL url("http://host/path"); + const std::string domain("host"); + const std::string path("/path"); + + scoped_refptr<MockPersistentCookieStore> store( + new MockPersistentCookieStore); + + std::vector<CanonicalCookie*> initial_cookies; + + AddCookieToList(domain, + "foo=bar; path=" + path, + now1, + &initial_cookies); + + // We have to manually build this cookie because it contains a control + // character, and our cookie line parser rejects control characters. + CanonicalCookie *cc = new CanonicalCookie(url, "baz", "\x05" "boo", domain, + path, now2, later, now2, false, + false, COOKIE_PRIORITY_DEFAULT); + initial_cookies.push_back(cc); + + AddCookieToList(domain, + "hello=world; path=" + path, + now3, + &initial_cookies); + + // Inject our initial cookies into the mock PersistentCookieStore. + store->SetLoadExpectation(true, initial_cookies); + + scoped_refptr<CookieMonster> cm(new CookieMonster(store.get(), NULL)); + + EXPECT_EQ("foo=bar; hello=world", GetCookies(cm.get(), url)); +} + } // namespace net diff --git a/chromium/net/cookies/parsed_cookie.cc b/chromium/net/cookies/parsed_cookie.cc index 125d3d998b9..60e0bbb6618 100644 --- a/chromium/net/cookies/parsed_cookie.cc +++ b/chromium/net/cookies/parsed_cookie.cc @@ -45,20 +45,8 @@ #include "net/cookies/parsed_cookie.h" #include "base/logging.h" -#include "base/metrics/histogram.h" #include "base/strings/string_util.h" -// TODO(jww): We are collecting several UMA statistics in this file, and they -// relate to http://crbug.com/238041. We are measuring stats related to control -// characters in cookies because, currently, we allow control characters in a -// variety of scenarios where various RFCs theoretically disallow them. These -// control characters have the potential to cause problems with certain web -// servers that reject HTTP requests that contain cookies with control -// characters. We are measuring whether disallowing such cookies would have a -// notable impact on our users. We want to collect these stats through 1 stable -// release, so these UMA stats should remain at least through the M29 -// branch-point. - namespace { const char kPathTokenName[] = "path"; @@ -148,11 +136,15 @@ bool IsValidCookieValue(const std::string& value) { return true; } +bool IsControlCharacter(unsigned char c) { + return (c >= 0) && (c <= 31); +} + bool IsValidCookieAttributeValue(const std::string& value) { // The greatest common denominator of cookie attribute values is // <any CHAR except CTLs or ";"> according to RFC 6265. for (std::string::const_iterator i = value.begin(); i != value.end(); ++i) { - if ((*i >= 0 && *i <= 31) || *i == ';') + if (IsControlCharacter(*i) || *i == ';') return false; } return true; @@ -194,9 +186,7 @@ CookiePriority ParsedCookie::Priority() const { } bool ParsedCookie::SetName(const std::string& name) { - bool valid_token = IsValidToken(name); - UMA_HISTOGRAM_BOOLEAN("Cookie.SetNameVaildity", valid_token); - if (!valid_token) + if (!IsValidToken(name)) return false; if (pairs_.empty()) pairs_.push_back(std::make_pair("", "")); @@ -205,10 +195,7 @@ bool ParsedCookie::SetName(const std::string& name) { } bool ParsedCookie::SetValue(const std::string& value) { - bool valid_cookie_value = IsValidCookieValue(value); - UMA_HISTOGRAM_BOOLEAN("Cookie.SetValueCookieValueValidity", - valid_cookie_value); - if (!valid_cookie_value) + if (!IsValidCookieValue(value)) return false; if (pairs_.empty()) pairs_.push_back(std::make_pair("", "")); @@ -354,15 +341,6 @@ std::string ParsedCookie::ParseValueString(const std::string& value) { // Parse all token/value pairs and populate pairs_. void ParsedCookie::ParseTokenValuePairs(const std::string& cookie_line) { - enum ParsedCookieStatus { - PARSED_COOKIE_STATUS_NOTHING = 0x0, - PARSED_COOKIE_STATUS_CONTROL_CHAR = 0x1, - PARSED_COOKIE_STATUS_INVALID = 0x2, - PARSED_COOKIE_STATUS_BOTH = - PARSED_COOKIE_STATUS_CONTROL_CHAR | PARSED_COOKIE_STATUS_INVALID - }; - int parsed_cookie_status = PARSED_COOKIE_STATUS_NOTHING; - pairs_.clear(); // Ok, here we go. We should be expecting to be starting somewhere @@ -407,17 +385,21 @@ void ParsedCookie::ParseTokenValuePairs(const std::string& cookie_line) { // OK, now try to parse a value. std::string::const_iterator value_start, value_end; ParseValue(&it, end, &value_start, &value_end); + // OK, we're finished with a Token/Value. pair.second = std::string(value_start, value_end); - if (!IsValidCookieAttributeValue(pair.second)) - parsed_cookie_status |= PARSED_COOKIE_STATUS_CONTROL_CHAR; - if (!IsValidToken(pair.second)) - parsed_cookie_status |= PARSED_COOKIE_STATUS_INVALID; - // From RFC2109: "Attributes (names) (attr) are case-insensitive." if (pair_num != 0) StringToLowerASCII(&pair.first); + // Ignore Set-Cookie directives contaning control characters. See + // http://crbug.com/238041. + if (!IsValidCookieAttributeValue(pair.first) || + !IsValidCookieAttributeValue(pair.second)) { + pairs_.clear(); + break; + } + pairs_.push_back(pair); // We've processed a token/value pair, we're either at the end of @@ -425,9 +407,6 @@ void ParsedCookie::ParseTokenValuePairs(const std::string& cookie_line) { if (it != end) ++it; } - - UMA_HISTOGRAM_ENUMERATION("Cookie.ParsedCookieStatus", parsed_cookie_status, - PARSED_COOKIE_STATUS_BOTH + 1); } void ParsedCookie::SetupAttributes() { @@ -478,11 +457,7 @@ bool ParsedCookie::SetBool(size_t* index, bool ParsedCookie::SetAttributePair(size_t* index, const std::string& key, const std::string& value) { - bool valid_attribute_pair = IsValidToken(key) && - IsValidCookieAttributeValue(value); - UMA_HISTOGRAM_BOOLEAN("Cookie.SetAttributePairCharsValidity", - valid_attribute_pair); - if (!valid_attribute_pair) + if (!(IsValidToken(key) && IsValidCookieAttributeValue(value))) return false; if (!IsValid()) return false; diff --git a/chromium/net/cookies/parsed_cookie_unittest.cc b/chromium/net/cookies/parsed_cookie_unittest.cc index ad4aba65d79..23e3768afe6 100644 --- a/chromium/net/cookies/parsed_cookie_unittest.cc +++ b/chromium/net/cookies/parsed_cookie_unittest.cc @@ -422,4 +422,73 @@ TEST(ParsedCookieTest, SetPriority) { EXPECT_EQ(COOKIE_PRIORITY_DEFAULT, pc.Priority()); } -} // namespace net +TEST(ParsedCookieTest, InvalidNonAlphanumericChars) { + ParsedCookie pc1("name=\x05"); + ParsedCookie pc2("name=foo" "\x1c" "bar"); + ParsedCookie pc3("name=foobar" "\x11"); + ParsedCookie pc4("name=\x02" "foobar"); + + ParsedCookie pc5("\x05=value"); + ParsedCookie pc6("foo" "\x05" "bar=value"); + ParsedCookie pc7("foobar" "\x05" "=value"); + ParsedCookie pc8("\x05" "foobar" "=value"); + + ParsedCookie pc9("foo" "\x05" "bar" "=foo" "\x05" "bar"); + + ParsedCookie pc10("foo=bar;ba" "\x05" "z=boo"); + ParsedCookie pc11("foo=bar;baz=bo" "\x05" "o"); + ParsedCookie pc12("foo=bar;ba" "\05" "z=bo" "\x05" "o"); + + EXPECT_FALSE(pc1.IsValid()); + EXPECT_FALSE(pc2.IsValid()); + EXPECT_FALSE(pc3.IsValid()); + EXPECT_FALSE(pc4.IsValid()); + EXPECT_FALSE(pc5.IsValid()); + EXPECT_FALSE(pc6.IsValid()); + EXPECT_FALSE(pc7.IsValid()); + EXPECT_FALSE(pc8.IsValid()); + EXPECT_FALSE(pc9.IsValid()); + EXPECT_FALSE(pc10.IsValid()); + EXPECT_FALSE(pc11.IsValid()); + EXPECT_FALSE(pc12.IsValid()); +} + +TEST(ParsedCookieTest, ValidNonAlphanumericChars) { + // Note that some of these words are pasted backwords thanks to poor vim bidi + // support. This should not affect the tests, however. + const char* pc1_literal = "name=العربية"; + const char* pc2_literal = "name=普通話"; + const char* pc3_literal = "name=ภาษาไทย"; + const char* pc4_literal = "name=עִבְרִית"; + const char* pc5_literal = "العربية=value"; + const char* pc6_literal = "普通話=value"; + const char* pc7_literal = "ภาษาไทย=value"; + const char* pc8_literal = "עִבְרִית=value"; + ParsedCookie pc1(pc1_literal); + ParsedCookie pc2(pc2_literal); + ParsedCookie pc3(pc3_literal); + ParsedCookie pc4(pc4_literal); + ParsedCookie pc5(pc5_literal); + ParsedCookie pc6(pc6_literal); + ParsedCookie pc7(pc7_literal); + ParsedCookie pc8(pc8_literal); + + EXPECT_TRUE(pc1.IsValid()); + EXPECT_EQ(pc1_literal, pc1.ToCookieLine()); + EXPECT_TRUE(pc2.IsValid()); + EXPECT_EQ(pc2_literal, pc2.ToCookieLine()); + EXPECT_TRUE(pc3.IsValid()); + EXPECT_EQ(pc3_literal, pc3.ToCookieLine()); + EXPECT_TRUE(pc4.IsValid()); + EXPECT_EQ(pc4_literal, pc4.ToCookieLine()); + EXPECT_TRUE(pc5.IsValid()); + EXPECT_EQ(pc5_literal, pc5.ToCookieLine()); + EXPECT_TRUE(pc6.IsValid()); + EXPECT_EQ(pc6_literal, pc6.ToCookieLine()); + EXPECT_TRUE(pc7.IsValid()); + EXPECT_EQ(pc7_literal, pc7.ToCookieLine()); + EXPECT_TRUE(pc8.IsValid()); + EXPECT_EQ(pc8_literal, pc8.ToCookieLine()); +} + +} diff --git a/chromium/net/data/ssl/certificates/subjectAltName_sanity_check.pem b/chromium/net/data/ssl/certificates/subjectAltName_sanity_check.pem index 46cf58de0cb..bb7f31b4218 100644 --- a/chromium/net/data/ssl/certificates/subjectAltName_sanity_check.pem +++ b/chromium/net/data/ssl/certificates/subjectAltName_sanity_check.pem @@ -1,54 +1,55 @@ Certificate: Data: Version: 3 (0x2) - Serial Number: - f2:f1:e7:8b:cf:09:30:f1 - Signature Algorithm: sha1WithRSAEncryption + Serial Number: 17778064637999560130 (0xf6b85f9895e5b5c2) + Signature Algorithm: sha1WithRSAEncryption Issuer: C=US, ST=California, L=Mountain View, O=Test CA, CN=127.0.0.1 Validity - Not Before: Apr 3 00:46:54 2012 GMT - Not After : Apr 1 00:46:54 2022 GMT + Not Before: Aug 16 02:31:34 2013 GMT + Not After : Aug 14 02:31:34 2023 GMT Subject: C=US, ST=California, L=Mountain View, O=Test CA, CN=127.0.0.1 Subject Public Key Info: Public Key Algorithm: rsaEncryption - RSA Public Key: (1024 bit) - Modulus (1024 bit): - 00:c8:0e:13:bb:da:d5:5a:d4:68:a2:11:90:ae:c3: - b3:f9:72:52:7d:e9:73:5c:49:60:ef:d3:49:05:9a: - c7:4e:01:4f:b0:c8:4c:18:34:2f:7b:84:27:ad:94: - 12:9b:e7:3d:38:6b:49:15:55:f6:c7:3a:8d:03:ec: - 3e:59:90:5c:b9:a6:41:af:f0:12:b8:87:b9:54:4d: - 1e:18:ba:41:96:d0:f3:bb:a0:d6:80:8e:29:10:72: - eb:3c:4c:c0:e2:f7:d8:61:2f:d8:63:c7:a7:79:f5: - 74:e0:2a:f0:5d:3e:eb:a2:36:09:4b:5d:35:31:56: - 1c:86:0e:8a:22:ad:1b:3f:27 + Public-Key: (1024 bit) + Modulus: + 00:bf:11:d3:18:37:84:53:8b:07:d3:7d:0a:dc:f7: + fc:ed:ce:8d:72:3a:29:af:17:e2:2b:d0:99:5f:3c: + 7b:29:a9:a8:3d:02:42:19:82:0b:df:5d:95:ac:60: + d9:08:69:ed:90:36:42:57:39:87:4c:cc:1e:8a:1d: + 7e:92:bb:7e:02:df:02:80:48:3f:38:21:cc:e9:d1: + b5:34:01:8f:92:17:ed:97:1d:11:2b:dd:df:fc:74: + f4:d6:66:9f:e3:e5:10:ea:ea:53:b2:a7:78:4b:96: + 31:06:38:0b:fa:0f:d8:58:9b:ff:2a:1f:2d:8c:ae: + 6c:42:73:4c:d2:cf:1b:b7:d1 Exponent: 65537 (0x10001) X509v3 extensions: + X509v3 Basic Constraints: critical + CA:TRUE X509v3 Subject Alternative Name: IP Address:127.0.0.2, IP Address:FE80:0:0:0:0:0:0:1, DNS:test.example, email:test@test.example, othername:<unsupported>, DirName:/CN=127.0.0.3 Signature Algorithm: sha1WithRSAEncryption - 32:46:49:70:be:e4:db:05:0e:7e:7a:e4:ea:5c:90:c6:4c:65: - 2d:03:ac:fb:d1:de:e4:26:e5:83:dc:5a:c8:4f:ff:b5:10:4e: - 39:21:7f:c8:37:f3:c6:7a:de:96:b3:30:e7:c7:87:6d:75:1e: - 14:30:17:6b:d2:76:0b:b8:43:39:c4:63:4c:50:8e:e1:0f:09: - ff:6c:7d:ab:c8:97:46:e8:04:70:9d:f5:e5:8c:b6:8c:b7:3d: - 8e:0f:59:1f:6a:fd:03:c2:be:a1:40:b7:9b:38:ca:55:f5:18: - c3:0d:35:01:12:a0:8d:ba:1b:41:a3:6e:68:8c:cf:52:f9:96: - 90:64 + ad:99:a8:25:29:15:1f:b8:c7:27:f0:c8:d7:2a:2a:66:54:07: + 2b:2c:b4:1e:fe:27:07:29:da:22:3d:7a:d8:4d:81:72:78:3e: + 96:5d:4c:42:ce:8c:c5:d1:d9:b3:ac:92:99:19:e5:2a:32:8a: + bc:ce:fb:58:a0:b9:e7:4b:44:d8:0c:2c:30:2f:fa:6c:48:7e: + 23:77:4f:67:e9:72:83:39:22:6f:2b:4d:25:16:3d:98:be:01: + 31:a0:55:0a:85:78:b8:b9:9c:66:e6:cb:7b:81:1c:fc:84:d1: + 79:1b:41:12:21:f8:c9:5b:fd:3c:a6:e4:6d:36:5b:0a:4c:aa: + bb:2b -----BEGIN CERTIFICATE----- -MIICsDCCAhmgAwIBAgIJAPLx54vPCTDxMA0GCSqGSIb3DQEBBQUAMGAxCzAJBgNV +MIICwzCCAiygAwIBAgIJAPa4X5iV5bXCMA0GCSqGSIb3DQEBBQUAMGAxCzAJBgNV BAYTAlVTMRMwEQYDVQQIDApDYWxpZm9ybmlhMRYwFAYDVQQHDA1Nb3VudGFpbiBW -aWV3MRAwDgYDVQQKDAdUZXN0IENBMRIwEAYDVQQDDAkxMjcuMC4wLjEwHhcNMTIw -NDAzMDA0NjU0WhcNMjIwNDAxMDA0NjU0WjBgMQswCQYDVQQGEwJVUzETMBEGA1UE +aWV3MRAwDgYDVQQKDAdUZXN0IENBMRIwEAYDVQQDDAkxMjcuMC4wLjEwHhcNMTMw +ODE2MDIzMTM0WhcNMjMwODE0MDIzMTM0WjBgMQswCQYDVQQGEwJVUzETMBEGA1UE CAwKQ2FsaWZvcm5pYTEWMBQGA1UEBwwNTW91bnRhaW4gVmlldzEQMA4GA1UECgwH VGVzdCBDQTESMBAGA1UEAwwJMTI3LjAuMC4xMIGfMA0GCSqGSIb3DQEBAQUAA4GN -ADCBiQKBgQDIDhO72tVa1GiiEZCuw7P5clJ96XNcSWDv00kFmsdOAU+wyEwYNC97 -hCetlBKb5z04a0kVVfbHOo0D7D5ZkFy5pkGv8BK4h7lUTR4YukGW0PO7oNaAjikQ -cus8TMDi99hhL9hjx6d59XTgKvBdPuuiNglLXTUxVhyGDooirRs/JwIDAQABo3Iw -cDBuBgNVHREEZzBlhwR/AAAChxD+gAAAAAAAAAAAAAAAAAABggx0ZXN0LmV4YW1w -bGWBEXRlc3RAdGVzdC5leGFtcGxloBIGAyoDBKALDAlpZ25vcmUgbWWkFjAUMRIw -EAYDVQQDDAkxMjcuMC4wLjMwDQYJKoZIhvcNAQEFBQADgYEAMkZJcL7k2wUOfnrk -6lyQxkxlLQOs+9He5Cblg9xayE//tRBOOSF/yDfzxnrelrMw58eHbXUeFDAXa9J2 -C7hDOcRjTFCO4Q8J/2x9q8iXRugEcJ315Yy2jLc9jg9ZH2r9A8K+oUC3mzjKVfUY -ww01ARKgjbobQaNuaIzPUvmWkGQ= +ADCBiQKBgQC/EdMYN4RTiwfTfQrc9/ztzo1yOimvF+Ir0JlfPHspqag9AkIZggvf +XZWsYNkIae2QNkJXOYdMzB6KHX6Su34C3wKASD84Iczp0bU0AY+SF+2XHREr3d/8 +dPTWZp/j5RDq6lOyp3hLljEGOAv6D9hYm/8qHy2MrmxCc0zSzxu30QIDAQABo4GE +MIGBMA8GA1UdEwEB/wQFMAMBAf8wbgYDVR0RBGcwZYcEfwAAAocQ/oAAAAAAAAAA +AAAAAAAAAYIMdGVzdC5leGFtcGxlgRF0ZXN0QHRlc3QuZXhhbXBsZaASBgMqAwSg +CwwJaWdub3JlIG1lpBYwFDESMBAGA1UEAwwJMTI3LjAuMC4zMA0GCSqGSIb3DQEB +BQUAA4GBAK2ZqCUpFR+4xyfwyNcqKmZUBysstB7+Jwcp2iI9ethNgXJ4PpZdTELO +jMXR2bOskpkZ5SoyirzO+1iguedLRNgMLDAv+mxIfiN3T2fpcoM5Im8rTSUWPZi+ +ATGgVQqFeLi5nGbmy3uBHPyE0XkbQRIh+Mlb/Tym5G02WwpMqrsr -----END CERTIFICATE----- diff --git a/chromium/net/data/ssl/scripts/ee.cnf b/chromium/net/data/ssl/scripts/ee.cnf index ad786c80ca8..5214f9e97f2 100644 --- a/chromium/net/data/ssl/scripts/ee.cnf +++ b/chromium/net/data/ssl/scripts/ee.cnf @@ -29,7 +29,8 @@ CN = Duplicate subjectAltName = IP:127.0.0.1 [req_san_sanity] -subjectAltName = @san_sanity +basicConstraints = critical, CA:true +subjectAltName = @san_sanity [san_sanity] IP.1 = 127.0.0.2 diff --git a/chromium/net/data/url_request_unittest/redirect-test.html.mock-http-headers b/chromium/net/data/url_request_unittest/redirect-test.html.mock-http-headers index 9fdd1c0b7b9..c59ef583ab0 100644 --- a/chromium/net/data/url_request_unittest/redirect-test.html.mock-http-headers +++ b/chromium/net/data/url_request_unittest/redirect-test.html.mock-http-headers @@ -1,2 +1,3 @@ HTTP/1.1 302 Redirect Location: with-headers.html +Cache-Control: max-age=10000 diff --git a/chromium/net/data/url_request_unittest/redirect-to-data.html b/chromium/net/data/url_request_unittest/redirect-to-data.html new file mode 100644 index 00000000000..ce013625030 --- /dev/null +++ b/chromium/net/data/url_request_unittest/redirect-to-data.html @@ -0,0 +1 @@ +hello diff --git a/chromium/net/data/url_request_unittest/redirect-to-data.html.mock-http-headers b/chromium/net/data/url_request_unittest/redirect-to-data.html.mock-http-headers new file mode 100644 index 00000000000..f049471e064 --- /dev/null +++ b/chromium/net/data/url_request_unittest/redirect-to-data.html.mock-http-headers @@ -0,0 +1,2 @@ +HTTP/1.1 302 Here I Am +Location: data:text/html,goodbye diff --git a/chromium/net/disk_cache/backend_impl.cc b/chromium/net/disk_cache/backend_impl.cc index 8d7fd461102..0f8c3fdd195 100644 --- a/chromium/net/disk_cache/backend_impl.cc +++ b/chromium/net/disk_cache/backend_impl.cc @@ -337,8 +337,6 @@ void BackendImpl::CleanupCache() { // This is a net_unittest, verify that we are not 'leaking' entries. File::WaitForPendingIO(&num_pending_io_); DCHECK(!num_refs_); - } else { - File::DropPendingIO(); } } block_files_.CloseFiles(); @@ -869,7 +867,7 @@ int32 BackendImpl::GetCurrentEntryId() const { } int BackendImpl::MaxFileSize() const { - return max_size_ / 8; + return cache_type() == net::PNACL_CACHE ? max_size_ : max_size_ / 8; } void BackendImpl::ModifyStorageSize(int32 old_size, int32 new_size) { diff --git a/chromium/net/disk_cache/backend_unittest.cc b/chromium/net/disk_cache/backend_unittest.cc index bc48a2eb203..6ccd1e0224a 100644 --- a/chromium/net/disk_cache/backend_unittest.cc +++ b/chromium/net/disk_cache/backend_unittest.cc @@ -66,6 +66,23 @@ scoped_ptr<disk_cache::BackendImpl> CreateExistingEntryCache( // Tests that can run with different types of caches. class DiskCacheBackendTest : public DiskCacheTestWithCache { protected: + // Some utility methods: + + // Perform IO operations on the cache until there is pending IO. + int GeneratePendingIO(net::TestCompletionCallback* cb); + + // Adds 5 sparse entries. |doomed_start| and |doomed_end| if not NULL, + // will be filled with times, used by DoomEntriesSince and DoomEntriesBetween. + // There are 4 entries after doomed_start and 2 after doomed_end. + void InitSparseCache(base::Time* doomed_start, base::Time* doomed_end); + + bool CreateSetOfRandomEntries(std::set<std::string>* key_pool); + bool EnumerateAndMatchKeys(int max_to_open, + void** iter, + std::set<std::string>* keys_to_match, + size_t* count); + + // Actual tests: void BackendBasics(); void BackendKeying(); void BackendShutdownWithPendingFileIO(bool fast); @@ -85,12 +102,6 @@ class DiskCacheBackendTest : public DiskCacheTestWithCache { void BackendInvalidEntryEnumeration(); void BackendFixEnumerators(); void BackendDoomRecent(); - - // Adds 5 sparse entries. |doomed_start| and |doomed_end| if not NULL, - // will be filled with times, used by DoomEntriesSince and DoomEntriesBetween. - // There are 4 entries after doomed_start and 2 after doomed_end. - void InitSparseCache(base::Time* doomed_start, base::Time* doomed_end); - void BackendDoomBetween(); void BackendTransaction(const std::string& name, int num_entries, bool load); void BackendRecoverInsert(); @@ -113,14 +124,154 @@ class DiskCacheBackendTest : public DiskCacheTestWithCache { void BackendDisable3(); void BackendDisable4(); void TracingBackendBasics(); - - bool CreateSetOfRandomEntries(std::set<std::string>* key_pool); - bool EnumerateAndMatchKeys(int max_to_open, - void** iter, - std::set<std::string>* keys_to_match, - size_t* count); }; +int DiskCacheBackendTest::GeneratePendingIO(net::TestCompletionCallback* cb) { + if (!use_current_thread_) { + ADD_FAILURE(); + return net::ERR_FAILED; + } + + disk_cache::Entry* entry; + int rv = cache_->CreateEntry("some key", &entry, cb->callback()); + if (cb->GetResult(rv) != net::OK) + return net::ERR_CACHE_CREATE_FAILURE; + + const int kSize = 25000; + scoped_refptr<net::IOBuffer> buffer(new net::IOBuffer(kSize)); + CacheTestFillBuffer(buffer->data(), kSize, false); + + for (int i = 0; i < 10 * 1024 * 1024; i += 64 * 1024) { + // We are using the current thread as the cache thread because we want to + // be able to call directly this method to make sure that the OS (instead + // of us switching thread) is returning IO pending. + if (!simple_cache_mode_) { + rv = static_cast<disk_cache::EntryImpl*>(entry)->WriteDataImpl( + 0, i, buffer.get(), kSize, cb->callback(), false); + } else { + rv = entry->WriteData(0, i, buffer.get(), kSize, cb->callback(), false); + } + + if (rv == net::ERR_IO_PENDING) + break; + if (rv != kSize) + rv = net::ERR_FAILED; + } + + // Don't call Close() to avoid going through the queue or we'll deadlock + // waiting for the operation to finish. + if (!simple_cache_mode_) + static_cast<disk_cache::EntryImpl*>(entry)->Release(); + else + entry->Close(); + + return rv; +} + +void DiskCacheBackendTest::InitSparseCache(base::Time* doomed_start, + base::Time* doomed_end) { + InitCache(); + + const int kSize = 50; + // This must be greater then MemEntryImpl::kMaxSparseEntrySize. + const int kOffset = 10 + 1024 * 1024; + + disk_cache::Entry* entry0 = NULL; + disk_cache::Entry* entry1 = NULL; + disk_cache::Entry* entry2 = NULL; + + scoped_refptr<net::IOBuffer> buffer(new net::IOBuffer(kSize)); + CacheTestFillBuffer(buffer->data(), kSize, false); + + ASSERT_EQ(net::OK, CreateEntry("zeroth", &entry0)); + ASSERT_EQ(kSize, WriteSparseData(entry0, 0, buffer.get(), kSize)); + ASSERT_EQ(kSize, + WriteSparseData(entry0, kOffset + kSize, buffer.get(), kSize)); + entry0->Close(); + + FlushQueueForTest(); + AddDelay(); + if (doomed_start) + *doomed_start = base::Time::Now(); + + // Order in rankings list: + // first_part1, first_part2, second_part1, second_part2 + ASSERT_EQ(net::OK, CreateEntry("first", &entry1)); + ASSERT_EQ(kSize, WriteSparseData(entry1, 0, buffer.get(), kSize)); + ASSERT_EQ(kSize, + WriteSparseData(entry1, kOffset + kSize, buffer.get(), kSize)); + entry1->Close(); + + ASSERT_EQ(net::OK, CreateEntry("second", &entry2)); + ASSERT_EQ(kSize, WriteSparseData(entry2, 0, buffer.get(), kSize)); + ASSERT_EQ(kSize, + WriteSparseData(entry2, kOffset + kSize, buffer.get(), kSize)); + entry2->Close(); + + FlushQueueForTest(); + AddDelay(); + if (doomed_end) + *doomed_end = base::Time::Now(); + + // Order in rankings list: + // third_part1, fourth_part1, third_part2, fourth_part2 + disk_cache::Entry* entry3 = NULL; + disk_cache::Entry* entry4 = NULL; + ASSERT_EQ(net::OK, CreateEntry("third", &entry3)); + ASSERT_EQ(kSize, WriteSparseData(entry3, 0, buffer.get(), kSize)); + ASSERT_EQ(net::OK, CreateEntry("fourth", &entry4)); + ASSERT_EQ(kSize, WriteSparseData(entry4, 0, buffer.get(), kSize)); + ASSERT_EQ(kSize, + WriteSparseData(entry3, kOffset + kSize, buffer.get(), kSize)); + ASSERT_EQ(kSize, + WriteSparseData(entry4, kOffset + kSize, buffer.get(), kSize)); + entry3->Close(); + entry4->Close(); + + FlushQueueForTest(); + AddDelay(); +} + +// Creates entries based on random keys. Stores these keys in |key_pool|. +bool DiskCacheBackendTest::CreateSetOfRandomEntries( + std::set<std::string>* key_pool) { + const int kNumEntries = 10; + + for (int i = 0; i < kNumEntries; ++i) { + std::string key = GenerateKey(true); + disk_cache::Entry* entry; + if (CreateEntry(key, &entry) != net::OK) + return false; + key_pool->insert(key); + entry->Close(); + } + return key_pool->size() == implicit_cast<size_t>(cache_->GetEntryCount()); +} + +// Performs iteration over the backend and checks that the keys of entries +// opened are in |keys_to_match|, then erases them. Up to |max_to_open| entries +// will be opened, if it is positive. Otherwise, iteration will continue until +// OpenNextEntry stops returning net::OK. +bool DiskCacheBackendTest::EnumerateAndMatchKeys( + int max_to_open, + void** iter, + std::set<std::string>* keys_to_match, + size_t* count) { + disk_cache::Entry* entry; + + while (OpenNextEntry(iter, &entry) == net::OK) { + if (!entry) + return false; + EXPECT_EQ(1U, keys_to_match->erase(entry->GetKey())); + entry->Close(); + ++(*count); + if (max_to_open >= 0 && implicit_cast<int>(*count) >= max_to_open) + break; + }; + + return true; +} + void DiskCacheBackendTest::BackendBasics() { InitCache(); disk_cache::Entry *entry1 = NULL, *entry2 = NULL; @@ -313,7 +464,7 @@ TEST_F(DiskCacheBackendTest, CreateBackend_MissingFile) { scoped_ptr<disk_cache::BackendImpl> cache(new disk_cache::BackendImpl( cache_path_, cache_thread.message_loop_proxy().get(), NULL)); int rv = cache->Init(cb.callback()); - ASSERT_EQ(net::ERR_FAILED, cb.GetResult(rv)); + EXPECT_EQ(net::ERR_FAILED, cb.GetResult(rv)); base::ThreadRestrictions::SetIOAllowed(prev); cache.reset(); @@ -344,67 +495,32 @@ TEST_F(DiskCacheBackendTest, ExternalFiles) { // Tests that we deal with file-level pending operations at destruction time. void DiskCacheBackendTest::BackendShutdownWithPendingFileIO(bool fast) { - net::TestCompletionCallback cb; - int rv; - - { - ASSERT_TRUE(CleanupCacheDir()); - base::Thread cache_thread("CacheThread"); - ASSERT_TRUE(cache_thread.StartWithOptions( - base::Thread::Options(base::MessageLoop::TYPE_IO, 0))); - - uint32 flags = disk_cache::kNoBuffering; - if (!fast) - flags |= disk_cache::kNoRandom; + ASSERT_TRUE(CleanupCacheDir()); + uint32 flags = disk_cache::kNoBuffering; + if (!fast) + flags |= disk_cache::kNoRandom; - UseCurrentThread(); - CreateBackend(flags, NULL); + UseCurrentThread(); + CreateBackend(flags, NULL); - disk_cache::EntryImpl* entry; - rv = cache_->CreateEntry( - "some key", reinterpret_cast<disk_cache::Entry**>(&entry), - cb.callback()); - ASSERT_EQ(net::OK, cb.GetResult(rv)); - - const int kSize = 25000; - scoped_refptr<net::IOBuffer> buffer(new net::IOBuffer(kSize)); - CacheTestFillBuffer(buffer->data(), kSize, false); - - for (int i = 0; i < 10 * 1024 * 1024; i += 64 * 1024) { - // We are using the current thread as the cache thread because we want to - // be able to call directly this method to make sure that the OS (instead - // of us switching thread) is returning IO pending. - rv = - entry->WriteDataImpl(0, i, buffer.get(), kSize, cb.callback(), false); - if (rv == net::ERR_IO_PENDING) - break; - EXPECT_EQ(kSize, rv); - } - - // Don't call Close() to avoid going through the queue or we'll deadlock - // waiting for the operation to finish. - entry->Release(); + net::TestCompletionCallback cb; + int rv = GeneratePendingIO(&cb); - // The cache destructor will see one pending operation here. - cache_.reset(); + // The cache destructor will see one pending operation here. + cache_.reset(); - if (rv == net::ERR_IO_PENDING) { - if (fast) - EXPECT_FALSE(cb.have_result()); - else - EXPECT_TRUE(cb.have_result()); - } + if (rv == net::ERR_IO_PENDING) { + if (fast || simple_cache_mode_) + EXPECT_FALSE(cb.have_result()); + else + EXPECT_TRUE(cb.have_result()); } base::MessageLoop::current()->RunUntilIdle(); -#if defined(OS_WIN) // Wait for the actual operation to complete, or we'll keep a file handle that - // may cause issues later. Note that on Posix systems even though this test - // uses a single thread, the actual IO is posted to a worker thread and the - // cache destructor breaks the link to reach cb when the operation completes. + // may cause issues later. rv = cb.GetResult(rv); -#endif } TEST_F(DiskCacheBackendTest, ShutdownWithPendingFileIO) { @@ -427,6 +543,40 @@ TEST_F(DiskCacheBackendTest, ShutdownWithPendingFileIO_Fast) { } #endif +// Tests that one cache instance is not affected by another one going away. +TEST_F(DiskCacheBackendTest, MultipleInstancesWithPendingFileIO) { + base::ScopedTempDir store; + ASSERT_TRUE(store.CreateUniqueTempDir()); + + net::TestCompletionCallback cb; + scoped_ptr<disk_cache::Backend> extra_cache; + int rv = disk_cache::CreateCacheBackend( + net::DISK_CACHE, net::CACHE_BACKEND_DEFAULT, store.path(), 0, + false, base::MessageLoopProxy::current().get(), NULL, + &extra_cache, cb.callback()); + ASSERT_EQ(net::OK, cb.GetResult(rv)); + ASSERT_TRUE(extra_cache.get() != NULL); + + ASSERT_TRUE(CleanupCacheDir()); + SetNewEviction(); // Match the expected behavior for integrity verification. + UseCurrentThread(); + + CreateBackend(disk_cache::kNoBuffering, NULL); + rv = GeneratePendingIO(&cb); + + // cache_ has a pending operation, and extra_cache will go away. + extra_cache.reset(); + + if (rv == net::ERR_IO_PENDING) + EXPECT_FALSE(cb.have_result()); + + base::MessageLoop::current()->RunUntilIdle(); + + // Wait for the actual operation to complete, or we'll keep a file handle that + // may cause issues later. + rv = cb.GetResult(rv); +} + // Tests that we deal with background-thread pending operations. void DiskCacheBackendTest::BackendShutdownWithPendingIO(bool fast) { net::TestCompletionCallback cb; @@ -1395,70 +1545,6 @@ TEST_F(DiskCacheBackendTest, MemoryOnlyDoomRecent) { BackendDoomRecent(); } -void DiskCacheBackendTest::InitSparseCache(base::Time* doomed_start, - base::Time* doomed_end) { - InitCache(); - - const int kSize = 50; - // This must be greater then MemEntryImpl::kMaxSparseEntrySize. - const int kOffset = 10 + 1024 * 1024; - - disk_cache::Entry* entry0 = NULL; - disk_cache::Entry* entry1 = NULL; - disk_cache::Entry* entry2 = NULL; - - scoped_refptr<net::IOBuffer> buffer(new net::IOBuffer(kSize)); - CacheTestFillBuffer(buffer->data(), kSize, false); - - ASSERT_EQ(net::OK, CreateEntry("zeroth", &entry0)); - ASSERT_EQ(kSize, WriteSparseData(entry0, 0, buffer.get(), kSize)); - ASSERT_EQ(kSize, - WriteSparseData(entry0, kOffset + kSize, buffer.get(), kSize)); - entry0->Close(); - - FlushQueueForTest(); - AddDelay(); - if (doomed_start) - *doomed_start = base::Time::Now(); - - // Order in rankings list: - // first_part1, first_part2, second_part1, second_part2 - ASSERT_EQ(net::OK, CreateEntry("first", &entry1)); - ASSERT_EQ(kSize, WriteSparseData(entry1, 0, buffer.get(), kSize)); - ASSERT_EQ(kSize, - WriteSparseData(entry1, kOffset + kSize, buffer.get(), kSize)); - entry1->Close(); - - ASSERT_EQ(net::OK, CreateEntry("second", &entry2)); - ASSERT_EQ(kSize, WriteSparseData(entry2, 0, buffer.get(), kSize)); - ASSERT_EQ(kSize, - WriteSparseData(entry2, kOffset + kSize, buffer.get(), kSize)); - entry2->Close(); - - FlushQueueForTest(); - AddDelay(); - if (doomed_end) - *doomed_end = base::Time::Now(); - - // Order in rankings list: - // third_part1, fourth_part1, third_part2, fourth_part2 - disk_cache::Entry* entry3 = NULL; - disk_cache::Entry* entry4 = NULL; - ASSERT_EQ(net::OK, CreateEntry("third", &entry3)); - ASSERT_EQ(kSize, WriteSparseData(entry3, 0, buffer.get(), kSize)); - ASSERT_EQ(net::OK, CreateEntry("fourth", &entry4)); - ASSERT_EQ(kSize, WriteSparseData(entry4, 0, buffer.get(), kSize)); - ASSERT_EQ(kSize, - WriteSparseData(entry3, kOffset + kSize, buffer.get(), kSize)); - ASSERT_EQ(kSize, - WriteSparseData(entry4, kOffset + kSize, buffer.get(), kSize)); - entry3->Close(); - entry4->Close(); - - FlushQueueForTest(); - AddDelay(); -} - TEST_F(DiskCacheBackendTest, MemoryOnlyDoomEntriesSinceSparse) { SetMemoryOnlyMode(); base::Time start; @@ -1509,6 +1595,7 @@ void DiskCacheBackendTest::BackendDoomBetween() { AddDelay(); Time middle_end = Time::Now(); + AddDelay(); ASSERT_EQ(net::OK, CreateEntry("fourth", &entry)); entry->Close(); @@ -1792,7 +1879,7 @@ TEST_F(DiskCacheTest, SimpleCacheControlRestart) { net::TestCompletionCallback cb; const int kRestartCount = 5; - for (int i=0; i < kRestartCount; ++i) { + for (int i = 0; i < kRestartCount; ++i) { cache.reset(new disk_cache::BackendImpl( cache_path_, cache_thread.message_loop_proxy(), NULL)); int rv = cache->Init(cb.callback()); @@ -3049,9 +3136,22 @@ TEST_F(DiskCacheBackendTest, TracingBackendBasics) { TracingBackendBasics(); } -// The simple cache backend isn't intended to work on windows, which has very -// different file system guarantees from Windows. -#if !defined(OS_WIN) +// The Simple Cache backend requires a few guarantees from the filesystem like +// atomic renaming of recently open files. Those guarantees are not provided in +// general on Windows. +#if defined(OS_POSIX) + +TEST_F(DiskCacheBackendTest, SimpleCacheShutdownWithPendingCreate) { + SetCacheType(net::APP_CACHE); + SetSimpleCacheMode(); + BackendShutdownWithPendingCreate(false); +} + +TEST_F(DiskCacheBackendTest, SimpleCacheShutdownWithPendingFileIO) { + SetCacheType(net::APP_CACHE); + SetSimpleCacheMode(); + BackendShutdownWithPendingFileIO(false); +} TEST_F(DiskCacheBackendTest, SimpleCacheBasics) { SetSimpleCacheMode(); @@ -3111,13 +3211,12 @@ TEST_F(DiskCacheBackendTest, SimpleDoomBetween) { BackendDoomBetween(); } -// See http://crbug.com/237450. -TEST_F(DiskCacheBackendTest, FLAKY_SimpleCacheDoomAll) { +TEST_F(DiskCacheBackendTest, SimpleCacheDoomAll) { SetSimpleCacheMode(); BackendDoomAll(); } -TEST_F(DiskCacheBackendTest, FLAKY_SimpleCacheAppCacheOnlyDoomAll) { +TEST_F(DiskCacheBackendTest, SimpleCacheAppCacheOnlyDoomAll) { SetCacheType(net::APP_CACHE); SetSimpleCacheMode(); BackendDoomAll(); @@ -3151,7 +3250,7 @@ TEST_F(DiskCacheBackendTest, SimpleCacheOpenMissingFile) { // Delete one of the files in the entry. base::FilePath to_delete_file = cache_path_.AppendASCII( - disk_cache::simple_util::GetFilenameFromKeyAndIndex(key, 0)); + disk_cache::simple_util::GetFilenameFromKeyAndFileIndex(key, 0)); EXPECT_TRUE(base::PathExists(to_delete_file)); EXPECT_TRUE(disk_cache::DeleteCacheFile(to_delete_file)); @@ -3160,9 +3259,8 @@ TEST_F(DiskCacheBackendTest, SimpleCacheOpenMissingFile) { // Confirm the rest of the files are gone. for (int i = 1; i < disk_cache::kSimpleEntryFileCount; ++i) { - base::FilePath - should_be_gone_file(cache_path_.AppendASCII( - disk_cache::simple_util::GetFilenameFromKeyAndIndex(key, i))); + base::FilePath should_be_gone_file(cache_path_.AppendASCII( + disk_cache::simple_util::GetFilenameFromKeyAndFileIndex(key, i))); EXPECT_FALSE(base::PathExists(should_be_gone_file)); } } @@ -3187,9 +3285,9 @@ TEST_F(DiskCacheBackendTest, SimpleCacheOpenBadFile) { entry->Close(); entry = NULL; - // Write an invalid header on stream 1. + // Write an invalid header for stream 0 and stream 1. base::FilePath entry_file1_path = cache_path_.AppendASCII( - disk_cache::simple_util::GetFilenameFromKeyAndIndex(key, 1)); + disk_cache::simple_util::GetFilenameFromKeyAndFileIndex(key, 0)); disk_cache::SimpleFileHeader header; header.initial_magic_number = GG_UINT64_C(0xbadf00d); @@ -3265,46 +3363,6 @@ TEST_F(DiskCacheBackendTest, SimpleCacheFixEnumerators) { BackendFixEnumerators(); } -// Creates entries based on random keys. Stores these keys in |key_pool|. -bool DiskCacheBackendTest::CreateSetOfRandomEntries( - std::set<std::string>* key_pool) { - const int kNumEntries = 10; - - for (int i = 0; i < kNumEntries; ++i) { - std::string key = GenerateKey(true); - disk_cache::Entry* entry; - if (CreateEntry(key, &entry) != net::OK) - return false; - key_pool->insert(key); - entry->Close(); - } - return key_pool->size() == implicit_cast<size_t>(cache_->GetEntryCount()); -} - -// Performs iteration over the backend and checks that the keys of entries -// opened are in |keys_to_match|, then erases them. Up to |max_to_open| entries -// will be opened, if it is positive. Otherwise, iteration will continue until -// OpenNextEntry stops returning net::OK. -bool DiskCacheBackendTest::EnumerateAndMatchKeys( - int max_to_open, - void** iter, - std::set<std::string>* keys_to_match, - size_t* count) { - disk_cache::Entry* entry; - - while (OpenNextEntry(iter, &entry) == net::OK) { - if (!entry) - return false; - EXPECT_EQ(1U, keys_to_match->erase(entry->GetKey())); - entry->Close(); - ++(*count); - if (max_to_open >= 0 && implicit_cast<int>(*count) >= max_to_open) - break; - }; - - return true; -} - // Tests basic functionality of the SimpleBackend implementation of the // enumeration API. TEST_F(DiskCacheBackendTest, SimpleCacheEnumerationBasics) { @@ -3412,4 +3470,8 @@ TEST_F(DiskCacheBackendTest, SimpleCacheEnumerationCorruption) { EXPECT_TRUE(keys_to_match.empty()); } -#endif // !defined(OS_WIN) +// TODO(pasko): Add a Simple Cache test that would simulate upgrade from the +// version with the index file in the cache directory to the version with the +// index file in subdirectory. + +#endif // defined(OS_POSIX) diff --git a/chromium/net/disk_cache/block_files.cc b/chromium/net/disk_cache/block_files.cc index fc378e6c449..896cdb16328 100644 --- a/chromium/net/disk_cache/block_files.cc +++ b/chromium/net/disk_cache/block_files.cc @@ -52,9 +52,17 @@ BlockHeader::BlockHeader(const BlockHeader& other) : header_(other.header_) { BlockHeader::~BlockHeader() { } -bool BlockHeader::CreateMapBlock(int target, int size, int* index) { - if (target <= 0 || target > kMaxNumBlocks || - size <= 0 || size > kMaxNumBlocks) { +bool BlockHeader::CreateMapBlock(int size, int* index) { + DCHECK(size > 0 && size <= kMaxNumBlocks); + int target = 0; + for (int i = size; i <= kMaxNumBlocks; i++) { + if (header_->empty[i - 1]) { + target = i; + break; + } + } + + if (!target) { NOTREACHED(); return false; } @@ -144,10 +152,9 @@ void BlockHeader::DeleteMapBlock(int index, int size) { // Note that this is a simplified version of DeleteMapBlock(). bool BlockHeader::UsedMapBlock(int index, int size) { - if (size < 0 || size > kMaxNumBlocks) { - NOTREACHED(); + if (size < 0 || size > kMaxNumBlocks) return false; - } + int byte_index = index / 8; uint8* byte_map = reinterpret_cast<uint8*>(header_->allocation_map); uint8 map_block = byte_map[byte_index]; @@ -177,7 +184,7 @@ void BlockHeader::FixAllocationCounters() { } } -bool BlockHeader::NeedToGrowBlockFile(int block_count) { +bool BlockHeader::NeedToGrowBlockFile(int block_count) const { bool have_space = false; int empty_blocks = 0; for (int i = 0; i < kMaxNumBlocks; i++) { @@ -195,9 +202,19 @@ bool BlockHeader::NeedToGrowBlockFile(int block_count) { return !have_space; } +bool BlockHeader::CanAllocate(int block_count) const { + DCHECK_GT(block_count, 0); + for (int i = block_count - 1; i < kMaxNumBlocks; i++) { + if (header_->empty[i]) + return true; + } + + return false; +} + int BlockHeader::EmptyBlocks() const { int empty_blocks = 0; - for (int i = 0; i < disk_cache::kMaxNumBlocks; i++) { + for (int i = 0; i < kMaxNumBlocks; i++) { empty_blocks += header_->empty[i] * (i + 1); if (header_->empty[i] < 0) return 0; @@ -205,6 +222,14 @@ int BlockHeader::EmptyBlocks() const { return empty_blocks; } +int BlockHeader::MinimumAllocations() const { + return header_->empty[kMaxNumBlocks - 1]; +} + +int BlockHeader::Capacity() const { + return header_->max_entries; +} + bool BlockHeader::ValidateCounters() const { if (header_->max_entries < 0 || header_->max_entries > kMaxBlocks || header_->num_entries < 0) @@ -217,10 +242,22 @@ bool BlockHeader::ValidateCounters() const { return true; } +int BlockHeader::FileId() const { + return header_->this_file; +} + +int BlockHeader::NextFileId() const { + return header_->next_file; +} + int BlockHeader::Size() const { return static_cast<int>(sizeof(*header_)); } +BlockFileHeader* BlockHeader::Header() { + return header_; +} + // ------------------------------------------------------------------------ BlockFiles::BlockFiles(const base::FilePath& path) @@ -260,7 +297,8 @@ bool BlockFiles::Init(bool create_files) { MappedFile* BlockFiles::GetFile(Addr address) { DCHECK(thread_checker_->CalledOnValidThread()); - DCHECK(block_files_.size() >= 4); + DCHECK_GE(block_files_.size(), + static_cast<size_t>(kFirstAdditionalBlockFile)); DCHECK(address.is_block_file() || !address.is_initialized()); if (!address.is_initialized()) return NULL; @@ -272,16 +310,20 @@ MappedFile* BlockFiles::GetFile(Addr address) { if (!OpenBlockFile(file_index)) return NULL; } - DCHECK(block_files_.size() >= static_cast<unsigned int>(file_index)); + DCHECK_GE(block_files_.size(), static_cast<unsigned int>(file_index)); return block_files_[file_index]; } bool BlockFiles::CreateBlock(FileType block_type, int block_count, Addr* block_address) { DCHECK(thread_checker_->CalledOnValidThread()); - if (block_type < RANKINGS || block_type > BLOCK_4K || - block_count < 1 || block_count > 4) + DCHECK_NE(block_type, EXTERNAL); + DCHECK_NE(block_type, BLOCK_FILES); + DCHECK_NE(block_type, BLOCK_ENTRIES); + DCHECK_NE(block_type, BLOCK_EVICTED); + if (block_count < 1 || block_count > kMaxNumBlocks) return false; + if (!init_) return false; @@ -290,22 +332,13 @@ bool BlockFiles::CreateBlock(FileType block_type, int block_count, return false; ScopedFlush flush(file); - BlockHeader header(file); + BlockHeader file_header(file); - int target_size = 0; - for (int i = block_count; i <= 4; i++) { - if (header->empty[i - 1]) { - target_size = i; - break; - } - } - - DCHECK(target_size); int index; - if (!header.CreateMapBlock(target_size, block_count, &index)) + if (!file_header.CreateMapBlock(block_count, &index)) return false; - Addr address(block_type, block_count, header->this_file, index); + Addr address(block_type, block_count, file_header.FileId(), index); block_address->set_value(address.value()); Trace("CreateBlock 0x%x", address.value()); return true; @@ -332,15 +365,17 @@ void BlockFiles::DeleteBlock(Addr address, bool deep) { if (deep) file->Write(zero_buffer_, size, offset); - BlockHeader header(file); - header.DeleteMapBlock(address.start_block(), address.num_blocks()); + BlockHeader file_header(file); + file_header.DeleteMapBlock(address.start_block(), address.num_blocks()); file->Flush(); - if (!header->num_entries) { + if (!file_header.Header()->num_entries) { // This file is now empty. Let's try to delete it. - FileType type = Addr::RequiredFileType(header->entry_size); - if (Addr::BlockSizeForFileType(RANKINGS) == header->entry_size) + FileType type = Addr::RequiredFileType(file_header.Header()->entry_size); + if (Addr::BlockSizeForFileType(RANKINGS) == + file_header.Header()->entry_size) { type = RANKINGS; + } RemoveEmptyFile(type); // Ignore failures. } } @@ -450,13 +485,14 @@ bool BlockFiles::OpenBlockFile(int index) { return false; } - BlockHeader header(file.get()); + BlockHeader file_header(file.get()); + BlockFileHeader* header = file_header.Header(); if (kBlockMagic != header->magic || kBlockVersion2 != header->version) { LOG(ERROR) << "Invalid file version or magic " << name.value(); return false; } - if (header->updating || !header.ValidateCounters()) { + if (header->updating || !file_header.ValidateCounters()) { // Last instance was not properly shutdown, or counters are out of sync. if (!FixBlockFileHeader(file.get())) { LOG(ERROR) << "Unable to fix block file " << name.value(); @@ -516,19 +552,19 @@ bool BlockFiles::GrowBlockFile(MappedFile* file, BlockFileHeader* header) { MappedFile* BlockFiles::FileForNewBlock(FileType block_type, int block_count) { COMPILE_ASSERT(RANKINGS == 1, invalid_file_type); MappedFile* file = block_files_[block_type - 1]; - BlockHeader header(file); + BlockHeader file_header(file); TimeTicks start = TimeTicks::Now(); - while (header.NeedToGrowBlockFile(block_count)) { - if (kMaxBlocks == header->max_entries) { + while (file_header.NeedToGrowBlockFile(block_count)) { + if (kMaxBlocks == file_header.Header()->max_entries) { file = NextFile(file); if (!file) return NULL; - header = BlockHeader(file); + file_header = BlockHeader(file); continue; } - if (!GrowBlockFile(file, header.Get())) + if (!GrowBlockFile(file, file_header.Header())) return NULL; break; } @@ -616,38 +652,39 @@ bool BlockFiles::RemoveEmptyFile(FileType block_type) { // DCHECK on header->updating because we may be fixing a crash. bool BlockFiles::FixBlockFileHeader(MappedFile* file) { ScopedFlush flush(file); - BlockHeader header(file); + BlockHeader file_header(file); int file_size = static_cast<int>(file->GetLength()); - if (file_size < header.Size()) + if (file_size < file_header.Size()) return false; // file_size > 2GB is also an error. const int kMinBlockSize = 36; const int kMaxBlockSize = 4096; + BlockFileHeader* header = file_header.Header(); if (header->entry_size < kMinBlockSize || header->entry_size > kMaxBlockSize || header->num_entries < 0) return false; // Make sure that we survive crashes. header->updating = 1; - int expected = header->entry_size * header->max_entries + header.Size(); + int expected = header->entry_size * header->max_entries + file_header.Size(); if (file_size != expected) { - int max_expected = header->entry_size * kMaxBlocks + header.Size(); + int max_expected = header->entry_size * kMaxBlocks + file_header.Size(); if (file_size < expected || header->empty[3] || file_size > max_expected) { NOTREACHED(); LOG(ERROR) << "Unexpected file size"; return false; } // We were in the middle of growing the file. - int num_entries = (file_size - header.Size()) / header->entry_size; + int num_entries = (file_size - file_header.Size()) / header->entry_size; header->max_entries = num_entries; } - header.FixAllocationCounters(); - int empty_blocks = header.EmptyBlocks(); + file_header.FixAllocationCounters(); + int empty_blocks = file_header.EmptyBlocks(); if (empty_blocks + header->num_entries > header->max_entries) header->num_entries = header->max_entries - empty_blocks; - if (!header.ValidateCounters()) + if (!file_header.ValidateCounters()) return false; header->updating = 0; @@ -671,7 +708,7 @@ void BlockFiles::GetFileStats(int index, int* used_count, int* load) { max_blocks += header->max_entries; int used = header->max_entries; - for (int i = 0; i < 4; i++) { + for (int i = 0; i < kMaxNumBlocks; i++) { used -= header->empty[i] * (i + 1); DCHECK_GE(used, 0); } @@ -687,7 +724,7 @@ void BlockFiles::GetFileStats(int index, int* used_count, int* load) { base::FilePath BlockFiles::Name(int index) { // The file format allows for 256 files. - DCHECK(index < 256 || index >= 0); + DCHECK(index < 256 && index >= 0); std::string tmp = base::StringPrintf("%s%d", kBlockName, index); return path_.AppendASCII(tmp); } diff --git a/chromium/net/disk_cache/block_files.h b/chromium/net/disk_cache/block_files.h index 353c5663df0..f8d5483a0b3 100644 --- a/chromium/net/disk_cache/block_files.h +++ b/chromium/net/disk_cache/block_files.h @@ -24,7 +24,11 @@ class ThreadChecker; namespace disk_cache { // An instance of this class represents the header of a block file in memory. -// Note that this class doesn't perform any file operation. +// Note that this class doesn't perform any file operation (as in it only deals +// with entities in memory). +// The header of a block file (and hence, this object) is all that is needed to +// perform common operations like allocating or releasing space for storage; +// actual access to that storage, however, is not performed through this class. class NET_EXPORT_PRIVATE BlockHeader { public: BlockHeader(); @@ -33,10 +37,9 @@ class NET_EXPORT_PRIVATE BlockHeader { BlockHeader(const BlockHeader& other); ~BlockHeader(); - // Creates a new entry on the allocation map, updating the apropriate - // counters. |target| is the type of block to use (number of empty blocks), - // and |size| is the actual number of blocks to use. - bool CreateMapBlock(int target, int size, int* index); + // Creates a new entry of |size| blocks on the allocation map, updating the + // apropriate counters. + bool CreateMapBlock(int size, int* index); // Deletes the block pointed by |index|. void DeleteMapBlock(int index, int block_size); @@ -49,20 +52,34 @@ class NET_EXPORT_PRIVATE BlockHeader { // Returns true if the current block file should not be used as-is to store // more records. |block_count| is the number of blocks to allocate. - bool NeedToGrowBlockFile(int block_count); + bool NeedToGrowBlockFile(int block_count) const; + + // Returns true if this block file can be used to store an extra record of + // size |block_count|. + bool CanAllocate(int block_count) const; // Returns the number of empty blocks for this file. int EmptyBlocks() const; + // Returns the minumum number of allocations that can be satisfied. + int MinimumAllocations() const; + + // Returns the number of blocks that this file can store. + int Capacity() const; + // Returns true if the counters look OK. bool ValidateCounters() const; + // Returns the identifiers of this and the next file (0 if there is none). + int FileId() const; + int NextFileId() const; + // Returns the size of the wrapped structure (BlockFileHeader). int Size() const; - BlockFileHeader* operator->() { return header_; } - void operator=(const BlockHeader& other) { header_ = other.header_; } - BlockFileHeader* Get() { return header_; } + // Returns a pointer to the underlying BlockFileHeader. + // TODO(rvargas): This may be removed with the support for V2. + BlockFileHeader* Header(); private: BlockFileHeader* header_; diff --git a/chromium/net/disk_cache/cache_creator.cc b/chromium/net/disk_cache/cache_creator.cc index 07a26c9d858..857d1714f7c 100644 --- a/chromium/net/disk_cache/cache_creator.cc +++ b/chromium/net/disk_cache/cache_creator.cc @@ -79,10 +79,10 @@ CacheCreator::~CacheCreator() { } int CacheCreator::Run() { - // TODO(gavinp,pasko): While simple backend development proceeds, we're only - // testing it against net::DISK_CACHE. Turn it on for more cache types as + // TODO(gavinp,pasko): Turn Simple Cache on for more cache types as // appropriate. - if (backend_type_ == net::CACHE_BACKEND_SIMPLE && type_ == net::DISK_CACHE) { + if (backend_type_ == net::CACHE_BACKEND_SIMPLE && + (type_ == net::DISK_CACHE || type_ == net::APP_CACHE)) { disk_cache::SimpleBackendImpl* simple_cache = new disk_cache::SimpleBackendImpl(path_, max_bytes_, type_, thread_.get(), net_log_); diff --git a/chromium/net/disk_cache/disk_cache_perftest.cc b/chromium/net/disk_cache/disk_cache_perftest.cc index f7a1b5969c6..6adc4bca709 100644 --- a/chromium/net/disk_cache/disk_cache_perftest.cc +++ b/chromium/net/disk_cache/disk_cache_perftest.cc @@ -8,8 +8,8 @@ #include "base/bind.h" #include "base/bind_helpers.h" #include "base/hash.h" -#include "base/perftimer.h" #include "base/strings/string_util.h" +#include "base/test/perf_time_logger.h" #include "base/test/test_file_util.h" #include "base/threading/thread.h" #include "base/timer/timer.h" @@ -53,7 +53,7 @@ bool TimeWrite(int num_entries, disk_cache::Backend* cache, MessageLoopHelper helper; CallbackTest callback(&helper, true); - PerfTimeLogger timer("Write disk cache entries"); + base::PerfTimeLogger timer("Write disk cache entries"); for (int i = 0; i < num_entries; i++) { TestEntry entry; @@ -107,7 +107,7 @@ bool TimeRead(int num_entries, disk_cache::Backend* cache, const char* message = cold ? "Read disk cache entries (cold)" : "Read disk cache entries (warm)"; - PerfTimeLogger timer(message); + base::PerfTimeLogger timer(message); for (int i = 0; i < num_entries; i++) { disk_cache::Entry* cache_entry; @@ -150,7 +150,7 @@ TEST_F(DiskCacheTest, Hash) { int seed = static_cast<int>(Time::Now().ToInternalValue()); srand(seed); - PerfTimeLogger timer("Hash disk cache keys"); + base::PerfTimeLogger timer("Hash disk cache keys"); for (int i = 0; i < 300000; i++) { std::string key = GenerateKey(true); base::Hash(key); @@ -223,7 +223,7 @@ TEST_F(DiskCacheTest, BlockFilesPerformance) { const int kNumEntries = 60000; disk_cache::Addr* address = new disk_cache::Addr[kNumEntries]; - PerfTimeLogger timer1("Fill three block-files"); + base::PerfTimeLogger timer1("Fill three block-files"); // Fill up the 32-byte block file (use three files). for (int i = 0; i < kNumEntries; i++) { @@ -232,7 +232,7 @@ TEST_F(DiskCacheTest, BlockFilesPerformance) { } timer1.Done(); - PerfTimeLogger timer2("Create and delete blocks"); + base::PerfTimeLogger timer2("Create and delete blocks"); for (int i = 0; i < 200000; i++) { int entry = rand() * (kNumEntries / RAND_MAX + 1); diff --git a/chromium/net/disk_cache/disk_cache_test_base.cc b/chromium/net/disk_cache/disk_cache_test_base.cc index dc7bb6c5fb2..3cf98a86b14 100644 --- a/chromium/net/disk_cache/disk_cache_test_base.cc +++ b/chromium/net/disk_cache/disk_cache_test_base.cc @@ -7,6 +7,7 @@ #include "base/file_util.h" #include "base/path_service.h" #include "base/run_loop.h" +#include "base/threading/platform_thread.h" #include "net/base/io_buffer.h" #include "net/base/net_errors.h" #include "net/base/test_completion_callback.h" @@ -221,6 +222,17 @@ void DiskCacheTestWithCache::TrimDeletedListForTest(bool empty) { } void DiskCacheTestWithCache::AddDelay() { + if (simple_cache_mode_) { + // The simple cache uses second resolution for many timeouts, so it's safest + // to advance by at least whole seconds before falling back into the normal + // disk cache epsilon advance. + const base::Time initial_time = base::Time::Now(); + do { + base::PlatformThread::YieldCurrentThread(); + } while (base::Time::Now() - + initial_time < base::TimeDelta::FromSeconds(1)); + } + base::Time initial = base::Time::Now(); while (base::Time::Now() <= initial) { base::PlatformThread::Sleep(base::TimeDelta::FromMilliseconds(1)); @@ -229,6 +241,8 @@ void DiskCacheTestWithCache::AddDelay() { void DiskCacheTestWithCache::TearDown() { base::RunLoop().RunUntilIdle(); + disk_cache::SimpleBackendImpl::FlushWorkerPoolForTesting(); + base::RunLoop().RunUntilIdle(); cache_.reset(); if (cache_thread_.IsRunning()) cache_thread_.Stop(); @@ -236,8 +250,9 @@ void DiskCacheTestWithCache::TearDown() { if (!memory_only_ && !simple_cache_mode_ && integrity_) { EXPECT_TRUE(CheckCacheIntegrity(cache_path_, new_eviction_, mask_)); } - - PlatformTest::TearDown(); + base::RunLoop().RunUntilIdle(); + disk_cache::SimpleBackendImpl::FlushWorkerPoolForTesting(); + DiskCacheTest::TearDown(); } void DiskCacheTestWithCache::InitMemoryCache() { diff --git a/chromium/net/disk_cache/disk_format_base.h b/chromium/net/disk_cache/disk_format_base.h index c8b7490abfd..31983817fcf 100644 --- a/chromium/net/disk_cache/disk_format_base.h +++ b/chromium/net/disk_cache/disk_format_base.h @@ -28,6 +28,7 @@ namespace disk_cache { typedef uint32 CacheAddr; const uint32 kBlockVersion2 = 0x20000; // Version 2.0. +const uint32 kBlockCurrentVersion = 0x30000; // Version 3.0. const uint32 kBlockMagic = 0xC104CAC3; const int kBlockHeaderSize = 8192; // Two pages: almost 64k entries diff --git a/chromium/net/disk_cache/entry_unittest.cc b/chromium/net/disk_cache/entry_unittest.cc index 857e07f1b7b..b8c9e668198 100644 --- a/chromium/net/disk_cache/entry_unittest.cc +++ b/chromium/net/disk_cache/entry_unittest.cc @@ -21,6 +21,7 @@ #include "net/disk_cache/mem_entry_impl.h" #include "net/disk_cache/simple/simple_entry_format.h" #include "net/disk_cache/simple/simple_entry_impl.h" +#include "net/disk_cache/simple/simple_synchronous_entry.h" #include "net/disk_cache/simple/simple_test_util.h" #include "net/disk_cache/simple/simple_util.h" #include "testing/gtest/include/gtest/gtest.h" @@ -62,7 +63,9 @@ class DiskCacheEntryTest : public DiskCacheTestWithCache { void UpdateSparseEntry(); void DoomSparseEntry(); void PartialSparseEntry(); - bool SimpleCacheMakeBadChecksumEntry(const char* key, int* data_size); + bool SimpleCacheMakeBadChecksumEntry(const std::string& key, int* data_size); + bool SimpleCacheThirdStreamFileExists(const char* key); + void SyncDoomEntry(const char* key); }; // This part of the test runs on the background thread. @@ -392,7 +395,6 @@ void DiskCacheEntryTest::ExternalSyncIOBackground(disk_cache::Entry* entry) { EXPECT_EQ( 25000, entry->ReadData(1, 0, buffer2.get(), kSize2, net::CompletionCallback())); - EXPECT_EQ(0, memcmp(buffer2->data(), buffer2->data(), 10000)); EXPECT_EQ(5000, entry->ReadData( 1, 30000, buffer2.get(), kSize2, net::CompletionCallback())); @@ -2298,8 +2300,9 @@ TEST_F(DiskCacheEntryTest, KeySanityCheck) { DisableIntegrityCheck(); } -// The simple cache backend isn't intended to work on Windows, which has very -// different file system guarantees from Linux. +// The Simple Cache backend requires a few guarantees from the filesystem like +// atomic renaming of recently open files. Those guarantees are not provided in +// general on Windows. #if defined(OS_POSIX) TEST_F(DiskCacheEntryTest, SimpleCacheInternalAsyncIO) { @@ -2414,7 +2417,7 @@ TEST_F(DiskCacheEntryTest, SimpleCacheDoomedEntry) { // Creates an entry with corrupted last byte in stream 0. // Requires SimpleCacheMode. -bool DiskCacheEntryTest::SimpleCacheMakeBadChecksumEntry(const char* key, +bool DiskCacheEntryTest::SimpleCacheMakeBadChecksumEntry(const std::string& key, int* data_size) { disk_cache::Entry* entry = NULL; @@ -2428,21 +2431,21 @@ bool DiskCacheEntryTest::SimpleCacheMakeBadChecksumEntry(const char* key, scoped_refptr<net::IOBuffer> buffer(new net::IOBuffer(kDataSize)); base::strlcpy(buffer->data(), data, kDataSize); - EXPECT_EQ(kDataSize, WriteData(entry, 0, 0, buffer.get(), kDataSize, false)); + EXPECT_EQ(kDataSize, WriteData(entry, 1, 0, buffer.get(), kDataSize, false)); entry->Close(); entry = NULL; // Corrupt the last byte of the data. base::FilePath entry_file0_path = cache_path_.AppendASCII( - disk_cache::simple_util::GetFilenameFromKeyAndIndex(key, 0)); + disk_cache::simple_util::GetFilenameFromKeyAndFileIndex(key, 0)); int flags = base::PLATFORM_FILE_WRITE | base::PLATFORM_FILE_OPEN; base::PlatformFile entry_file0 = base::CreatePlatformFile(entry_file0_path, flags, NULL, NULL); if (entry_file0 == base::kInvalidPlatformFileValue) return false; + int64 file_offset = - disk_cache::simple_util::GetFileOffsetFromKeyAndDataOffset( - key, kDataSize - 2); + sizeof(disk_cache::SimpleFileHeader) + key.size() + kDataSize - 2; EXPECT_EQ(1, base::WritePlatformFile(entry_file0, file_offset, "X", 1)); if (!base::ClosePlatformFile(entry_file0)) return false; @@ -2466,10 +2469,10 @@ TEST_F(DiskCacheEntryTest, SimpleCacheBadChecksum) { ScopedEntryPtr entry_closer(entry); const int kReadBufferSize = 200; - EXPECT_GE(kReadBufferSize, entry->GetDataSize(0)); + EXPECT_GE(kReadBufferSize, entry->GetDataSize(1)); scoped_refptr<net::IOBuffer> read_buffer(new net::IOBuffer(kReadBufferSize)); EXPECT_EQ(net::ERR_CACHE_CHECKSUM_MISMATCH, - ReadData(entry, 0, 0, read_buffer.get(), kReadBufferSize)); + ReadData(entry, 1, 0, read_buffer.get(), kReadBufferSize)); } // Tests that an entry that has had an IO error occur can still be Doomed(). @@ -2488,10 +2491,10 @@ TEST_F(DiskCacheEntryTest, SimpleCacheErrorThenDoom) { ScopedEntryPtr entry_closer(entry); const int kReadBufferSize = 200; - EXPECT_GE(kReadBufferSize, entry->GetDataSize(0)); + EXPECT_GE(kReadBufferSize, entry->GetDataSize(1)); scoped_refptr<net::IOBuffer> read_buffer(new net::IOBuffer(kReadBufferSize)); EXPECT_EQ(net::ERR_CACHE_CHECKSUM_MISMATCH, - ReadData(entry, 0, 0, read_buffer.get(), kReadBufferSize)); + ReadData(entry, 1, 0, read_buffer.get(), kReadBufferSize)); entry->Doom(); // Should not crash. } @@ -2530,7 +2533,7 @@ TEST_F(DiskCacheEntryTest, SimpleCacheNoEOF) { // record. int kTruncationBytes = -implicit_cast<int>(sizeof(disk_cache::SimpleFileEOF)); const base::FilePath entry_path = cache_path_.AppendASCII( - disk_cache::simple_util::GetFilenameFromKeyAndIndex(key, 0)); + disk_cache::simple_util::GetFilenameFromKeyAndFileIndex(key, 0)); const int64 invalid_size = disk_cache::simple_util::GetFileSizeFromKeyAndDataSize(key, kTruncationBytes); @@ -2558,13 +2561,12 @@ TEST_F(DiskCacheEntryTest, SimpleCacheNonOptimisticOperationsBasic) { CacheTestFillBuffer(write_buffer->data(), write_buffer->size(), false); EXPECT_EQ( write_buffer->size(), - WriteData(entry, 0, 0, write_buffer.get(), write_buffer->size(), false)); + WriteData(entry, 1, 0, write_buffer.get(), write_buffer->size(), false)); scoped_refptr<net::IOBufferWithSize> read_buffer( new net::IOBufferWithSize(kBufferSize)); - EXPECT_EQ( - read_buffer->size(), - ReadData(entry, 0, 0, read_buffer.get(), read_buffer->size())); + EXPECT_EQ(read_buffer->size(), + ReadData(entry, 1, 0, read_buffer.get(), read_buffer->size())); } TEST_F(DiskCacheEntryTest, SimpleCacheNonOptimisticOperationsDontBlock) { @@ -2591,7 +2593,7 @@ TEST_F(DiskCacheEntryTest, SimpleCacheNonOptimisticOperationsDontBlock) { CacheTestFillBuffer(write_buffer->data(), write_buffer->size(), false); CallbackTest write_callback(&helper, false); int ret = entry->WriteData( - 0, + 1, 0, write_buffer.get(), write_buffer->size(), @@ -2624,7 +2626,7 @@ TEST_F(DiskCacheEntryTest, CacheTestFillBuffer(write_buffer->data(), write_buffer->size(), false); CallbackTest write_callback(&helper, false); int ret = entry->WriteData( - 0, + 1, 0, write_buffer.get(), write_buffer->size(), @@ -2637,7 +2639,7 @@ TEST_F(DiskCacheEntryTest, new net::IOBufferWithSize(kBufferSize)); CallbackTest read_callback(&helper, false); ret = entry->ReadData( - 0, + 1, 0, read_buffer.get(), read_buffer->size(), @@ -2689,7 +2691,7 @@ TEST_F(DiskCacheEntryTest, SimpleCacheOptimistic) { // This write may or may not be optimistic (it depends if the previous // optimistic create already finished by the time we call the write here). int ret = entry->WriteData( - 0, + 1, 0, buffer1.get(), kSize1, @@ -2702,7 +2704,7 @@ TEST_F(DiskCacheEntryTest, SimpleCacheOptimistic) { // This Read must not be optimistic, since we don't support that yet. EXPECT_EQ(net::ERR_IO_PENDING, entry->ReadData( - 0, + 1, 0, buffer1_read.get(), kSize1, @@ -2715,7 +2717,7 @@ TEST_F(DiskCacheEntryTest, SimpleCacheOptimistic) { // should be empty, so the next Write operation must run as optimistic. EXPECT_EQ(kSize2, entry->WriteData( - 0, + 1, 0, buffer2.get(), kSize2, @@ -2726,7 +2728,7 @@ TEST_F(DiskCacheEntryTest, SimpleCacheOptimistic) { // operation finishes and we can then test for HasOneRef() below. EXPECT_EQ(net::ERR_IO_PENDING, entry->ReadData( - 0, + 1, 0, buffer2_read.get(), kSize2, @@ -2832,7 +2834,7 @@ TEST_F(DiskCacheEntryTest, SimpleCacheOptimistic4) { // operation finishes. Write must fail since we are writing in a closed entry. EXPECT_EQ( net::ERR_IO_PENDING, - entry->WriteData(0, 0, buffer1.get(), kSize1, cb.callback(), false)); + entry->WriteData(1, 0, buffer1.get(), kSize1, cb.callback(), false)); EXPECT_EQ(net::ERR_FAILED, cb.GetResult(net::ERR_IO_PENDING)); // Finish running the pending tasks so that we fully complete the close @@ -2861,12 +2863,12 @@ TEST_F(DiskCacheEntryTest, SimpleCacheOptimistic4) { // entry. EXPECT_EQ(kSize1, entry2->WriteData( - 0, 0, buffer1.get(), kSize1, net::CompletionCallback(), false)); + 1, 0, buffer1.get(), kSize1, net::CompletionCallback(), false)); // Lets do another read so we block until both the write and the read // operation finishes and we can then test for HasOneRef() below. EXPECT_EQ(net::ERR_IO_PENDING, - entry2->ReadData(0, 0, buffer1.get(), kSize1, cb.callback())); + entry2->ReadData(1, 0, buffer1.get(), kSize1, cb.callback())); EXPECT_EQ(kSize1, cb.GetResult(net::ERR_IO_PENDING)); // Check that we are not leaking. @@ -2875,9 +2877,7 @@ TEST_F(DiskCacheEntryTest, SimpleCacheOptimistic4) { entry2->Close(); } -// This test is flaky because of the race of Create followed by a Doom. -// See test SimpleCacheCreateDoomRace. -TEST_F(DiskCacheEntryTest, DISABLED_SimpleCacheOptimistic5) { +TEST_F(DiskCacheEntryTest, SimpleCacheOptimistic5) { // Test sequence: // Create, Doom, Write, Read, Close. SetSimpleCacheMode(); @@ -2899,11 +2899,11 @@ TEST_F(DiskCacheEntryTest, DISABLED_SimpleCacheOptimistic5) { EXPECT_EQ( net::ERR_IO_PENDING, - entry->WriteData(0, 0, buffer1.get(), kSize1, cb.callback(), false)); + entry->WriteData(1, 0, buffer1.get(), kSize1, cb.callback(), false)); EXPECT_EQ(kSize1, cb.GetResult(net::ERR_IO_PENDING)); EXPECT_EQ(net::ERR_IO_PENDING, - entry->ReadData(0, 0, buffer1.get(), kSize1, cb.callback())); + entry->ReadData(1, 0, buffer1.get(), kSize1, cb.callback())); EXPECT_EQ(kSize1, cb.GetResult(net::ERR_IO_PENDING)); // Check that we are not leaking. @@ -2933,7 +2933,7 @@ TEST_F(DiskCacheEntryTest, SimpleCacheOptimistic6) { EXPECT_EQ( net::ERR_IO_PENDING, - entry->WriteData(0, 0, buffer1.get(), kSize1, cb.callback(), false)); + entry->WriteData(1, 0, buffer1.get(), kSize1, cb.callback(), false)); EXPECT_EQ(kSize1, cb.GetResult(net::ERR_IO_PENDING)); entry->Doom(); @@ -2941,15 +2941,11 @@ TEST_F(DiskCacheEntryTest, SimpleCacheOptimistic6) { // This Read must not be optimistic, since we don't support that yet. EXPECT_EQ(net::ERR_IO_PENDING, - entry->ReadData(0, 0, buffer1_read.get(), kSize1, cb.callback())); + entry->ReadData(1, 0, buffer1_read.get(), kSize1, cb.callback())); EXPECT_EQ(kSize1, cb.GetResult(net::ERR_IO_PENDING)); EXPECT_EQ(0, memcmp(buffer1->data(), buffer1_read->data(), kSize1)); entry->Doom(); - - // Check that we are not leaking. - EXPECT_TRUE( - static_cast<disk_cache::SimpleEntryImpl*>(entry)->HasOneRef()); } // Confirm that IO buffers are not referenced by the Simple Cache after a write @@ -2976,7 +2972,7 @@ TEST_F(DiskCacheEntryTest, SimpleCacheOptimisticWriteReleases) { // operations. To ensure the queue is empty, we issue a write and wait until // it completes. EXPECT_EQ(kWriteSize, - WriteData(entry, 0, 0, buffer1.get(), kWriteSize, false)); + WriteData(entry, 1, 0, buffer1.get(), kWriteSize, false)); EXPECT_TRUE(buffer1->HasOneRef()); // Finally, we should perform an optimistic write and confirm that all @@ -2988,7 +2984,7 @@ TEST_F(DiskCacheEntryTest, SimpleCacheOptimisticWriteReleases) { EXPECT_TRUE(buffer1->HasOneRef()); } -TEST_F(DiskCacheEntryTest, DISABLED_SimpleCacheCreateDoomRace) { +TEST_F(DiskCacheEntryTest, SimpleCacheCreateDoomRace) { // Test sequence: // Create, Doom, Write, Close, Check files are not on disk anymore. SetSimpleCacheMode(); @@ -3006,20 +3002,13 @@ TEST_F(DiskCacheEntryTest, DISABLED_SimpleCacheCreateDoomRace) { cache_->CreateEntry(key, &entry, net::CompletionCallback())); EXPECT_NE(null, entry); - cache_->DoomEntry(key, cb.callback()); + EXPECT_EQ(net::ERR_IO_PENDING, cache_->DoomEntry(key, cb.callback())); EXPECT_EQ(net::OK, cb.GetResult(net::ERR_IO_PENDING)); - // Lets do a Write so we block until all operations are done, so we can check - // the HasOneRef() below. This call can't be optimistic and we are checking - // that here. EXPECT_EQ( - net::ERR_IO_PENDING, + kSize1, entry->WriteData(0, 0, buffer1.get(), kSize1, cb.callback(), false)); - EXPECT_EQ(kSize1, cb.GetResult(net::ERR_IO_PENDING)); - // Check that we are not leaking. - EXPECT_TRUE( - static_cast<disk_cache::SimpleEntryImpl*>(entry)->HasOneRef()); entry->Close(); // Finish running the pending tasks so that we fully complete the close @@ -3028,12 +3017,104 @@ TEST_F(DiskCacheEntryTest, DISABLED_SimpleCacheCreateDoomRace) { for (int i = 0; i < disk_cache::kSimpleEntryFileCount; ++i) { base::FilePath entry_file_path = cache_path_.AppendASCII( - disk_cache::simple_util::GetFilenameFromKeyAndIndex(key, i)); + disk_cache::simple_util::GetFilenameFromKeyAndFileIndex(key, i)); base::PlatformFileInfo info; EXPECT_FALSE(file_util::GetFileInfo(entry_file_path, &info)); } } +TEST_F(DiskCacheEntryTest, SimpleCacheDoomCreateRace) { + // This test runs as APP_CACHE to make operations more synchronous. Test + // sequence: + // Create, Doom, Create. + SetCacheType(net::APP_CACHE); + SetSimpleCacheMode(); + InitCache(); + disk_cache::Entry* null = NULL; + const char key[] = "the first key"; + + net::TestCompletionCallback create_callback; + + disk_cache::Entry* entry1 = NULL; + ASSERT_EQ(net::OK, + create_callback.GetResult( + cache_->CreateEntry(key, &entry1, create_callback.callback()))); + ScopedEntryPtr entry1_closer(entry1); + EXPECT_NE(null, entry1); + + net::TestCompletionCallback doom_callback; + EXPECT_EQ(net::ERR_IO_PENDING, + cache_->DoomEntry(key, doom_callback.callback())); + + disk_cache::Entry* entry2 = NULL; + ASSERT_EQ(net::OK, + create_callback.GetResult( + cache_->CreateEntry(key, &entry2, create_callback.callback()))); + ScopedEntryPtr entry2_closer(entry2); + EXPECT_EQ(net::OK, doom_callback.GetResult(net::ERR_IO_PENDING)); +} + +TEST_F(DiskCacheEntryTest, SimpleCacheDoomDoom) { + // Test sequence: + // Create, Doom, Create, Doom (1st entry), Open. + SetSimpleCacheMode(); + InitCache(); + disk_cache::Entry* null = NULL; + + const char key[] = "the first key"; + + disk_cache::Entry* entry1 = NULL; + ASSERT_EQ(net::OK, CreateEntry(key, &entry1)); + ScopedEntryPtr entry1_closer(entry1); + EXPECT_NE(null, entry1); + + EXPECT_EQ(net::OK, DoomEntry(key)); + + disk_cache::Entry* entry2 = NULL; + ASSERT_EQ(net::OK, CreateEntry(key, &entry2)); + ScopedEntryPtr entry2_closer(entry2); + EXPECT_NE(null, entry2); + + // Redundantly dooming entry1 should not delete entry2. + disk_cache::SimpleEntryImpl* simple_entry1 = + static_cast<disk_cache::SimpleEntryImpl*>(entry1); + net::TestCompletionCallback cb; + EXPECT_EQ(net::OK, + cb.GetResult(simple_entry1->DoomEntry(cb.callback()))); + + disk_cache::Entry* entry3 = NULL; + ASSERT_EQ(net::OK, OpenEntry(key, &entry3)); + ScopedEntryPtr entry3_closer(entry3); + EXPECT_NE(null, entry3); +} + +TEST_F(DiskCacheEntryTest, SimpleCacheDoomCreateDoom) { + // Test sequence: + // Create, Doom, Create, Doom. + SetSimpleCacheMode(); + InitCache(); + + disk_cache::Entry* null = NULL; + + const char key[] = "the first key"; + + disk_cache::Entry* entry1 = NULL; + ASSERT_EQ(net::OK, CreateEntry(key, &entry1)); + ScopedEntryPtr entry1_closer(entry1); + EXPECT_NE(null, entry1); + + entry1->Doom(); + + disk_cache::Entry* entry2 = NULL; + ASSERT_EQ(net::OK, CreateEntry(key, &entry2)); + ScopedEntryPtr entry2_closer(entry2); + EXPECT_NE(null, entry2); + + entry2->Doom(); + + // This test passes if it doesn't crash. +} + // Checks that an optimistic Create would fail later on a racing Open. TEST_F(DiskCacheEntryTest, SimpleCacheOptimisticCreateFailsOnOpen) { SetSimpleCacheMode(); @@ -3078,15 +3159,16 @@ TEST_F(DiskCacheEntryTest, SimpleCacheEvictOldEntries) { scoped_refptr<net::IOBuffer> buffer(new net::IOBuffer(kWriteSize)); CacheTestFillBuffer(buffer->data(), kWriteSize, false); EXPECT_EQ(kWriteSize, - WriteData(entry, 0, 0, buffer.get(), kWriteSize, false)); + WriteData(entry, 1, 0, buffer.get(), kWriteSize, false)); entry->Close(); + AddDelay(); std::string key2("the key prefix"); for (int i = 0; i < kNumExtraEntries; i++) { ASSERT_EQ(net::OK, CreateEntry(key2 + base::StringPrintf("%d", i), &entry)); ScopedEntryPtr entry_closer(entry); EXPECT_EQ(kWriteSize, - WriteData(entry, 0, 0, buffer.get(), kWriteSize, false)); + WriteData(entry, 1, 0, buffer.get(), kWriteSize, false)); } // TODO(pasko): Find a way to wait for the eviction task(s) to finish by using @@ -3122,7 +3204,7 @@ TEST_F(DiskCacheEntryTest, SimpleCacheInFlightTruncate) { ASSERT_EQ(net::OK, CreateEntry(key, &entry)); EXPECT_EQ(kBufferSize, - WriteData(entry, 0, 0, write_buffer.get(), kBufferSize, false)); + WriteData(entry, 1, 0, write_buffer.get(), kBufferSize, false)); entry->Close(); entry = NULL; @@ -3137,7 +3219,7 @@ TEST_F(DiskCacheEntryTest, SimpleCacheInFlightTruncate) { scoped_refptr<net::IOBuffer> read_buffer(new net::IOBuffer(kReadBufferSize)); CallbackTest read_callback(&helper, false); EXPECT_EQ(net::ERR_IO_PENDING, - entry->ReadData(0, + entry->ReadData(1, 0, read_buffer.get(), kReadBufferSize, @@ -3151,7 +3233,7 @@ TEST_F(DiskCacheEntryTest, SimpleCacheInFlightTruncate) { CacheTestFillBuffer(truncate_buffer->data(), kReadBufferSize, false); CallbackTest truncate_callback(&helper, false); EXPECT_EQ(net::ERR_IO_PENDING, - entry->WriteData(0, + entry->WriteData(1, 0, truncate_buffer.get(), kReadBufferSize, @@ -3191,7 +3273,7 @@ TEST_F(DiskCacheEntryTest, SimpleCacheInFlightRead) { CallbackTest write_callback(&helper, false); EXPECT_EQ(net::ERR_IO_PENDING, - entry->WriteData(0, + entry->WriteData(1, 0, write_buffer.get(), kBufferSize, @@ -3203,7 +3285,7 @@ TEST_F(DiskCacheEntryTest, SimpleCacheInFlightRead) { scoped_refptr<net::IOBuffer> read_buffer(new net::IOBuffer(kBufferSize)); CallbackTest read_callback(&helper, false); EXPECT_EQ(net::ERR_IO_PENDING, - entry->ReadData(0, + entry->ReadData(1, 0, read_buffer.get(), kBufferSize, @@ -3287,19 +3369,19 @@ TEST_F(DiskCacheEntryTest, SimpleCacheMultipleReadersCheckCRC2) { disk_cache::Entry* entry = NULL; ASSERT_EQ(net::OK, OpenEntry(key, &entry)); ScopedEntryPtr entry_closer(entry); - EXPECT_EQ(1, ReadData(entry, 0, 0, read_buffer1.get(), 1)); + EXPECT_EQ(1, ReadData(entry, 1, 0, read_buffer1.get(), 1)); // Advance the 2nd reader by the same amount. disk_cache::Entry* entry2 = NULL; EXPECT_EQ(net::OK, OpenEntry(key, &entry2)); ScopedEntryPtr entry2_closer(entry2); - EXPECT_EQ(1, ReadData(entry2, 0, 0, read_buffer2.get(), 1)); + EXPECT_EQ(1, ReadData(entry2, 1, 0, read_buffer2.get(), 1)); // Continue reading 1st. - EXPECT_GT(0, ReadData(entry, 0, 1, read_buffer1.get(), size)); + EXPECT_GT(0, ReadData(entry, 1, 1, read_buffer1.get(), size)); // This read should fail as well because we have previous read failures. - EXPECT_GT(0, ReadData(entry2, 0, 1, read_buffer2.get(), 1)); + EXPECT_GT(0, ReadData(entry2, 1, 1, read_buffer2.get(), 1)); DisableIntegrityCheck(); } @@ -3323,7 +3405,7 @@ TEST_F(DiskCacheEntryTest, SimpleCacheReadCombineCRC) { ASSERT_EQ(net::OK, CreateEntry(key, &entry)); EXPECT_NE(null, entry); - EXPECT_EQ(kSize, WriteData(entry, 0, 0, buffer1.get(), kSize, false)); + EXPECT_EQ(kSize, WriteData(entry, 1, 0, buffer1.get(), kSize, false)); entry->Close(); disk_cache::Entry* entry2 = NULL; @@ -3334,14 +3416,14 @@ TEST_F(DiskCacheEntryTest, SimpleCacheReadCombineCRC) { int offset = 0; int buf_len = kHalfSize; scoped_refptr<net::IOBuffer> buffer1_read1(new net::IOBuffer(buf_len)); - EXPECT_EQ(buf_len, ReadData(entry2, 0, offset, buffer1_read1.get(), buf_len)); + EXPECT_EQ(buf_len, ReadData(entry2, 1, offset, buffer1_read1.get(), buf_len)); EXPECT_EQ(0, memcmp(buffer1->data(), buffer1_read1->data(), buf_len)); // Read the second half of the data. offset = buf_len; buf_len = kHalfSize; scoped_refptr<net::IOBuffer> buffer1_read2(new net::IOBuffer(buf_len)); - EXPECT_EQ(buf_len, ReadData(entry2, 0, offset, buffer1_read2.get(), buf_len)); + EXPECT_EQ(buf_len, ReadData(entry2, 1, offset, buffer1_read2.get(), buf_len)); char* buffer1_data = buffer1->data() + offset; EXPECT_EQ(0, memcmp(buffer1_data, buffer1_read2->data(), buf_len)); @@ -3403,4 +3485,297 @@ TEST_F(DiskCacheEntryTest, SimpleCacheNonSequentialWrite) { entry = NULL; } +// Test that changing stream1 size does not affect stream0 (stream0 and stream1 +// are stored in the same file in Simple Cache). +TEST_F(DiskCacheEntryTest, SimpleCacheStream1SizeChanges) { + SetSimpleCacheMode(); + InitCache(); + disk_cache::Entry* entry = NULL; + const char key[] = "the key"; + const int kSize = 100; + scoped_refptr<net::IOBuffer> buffer(new net::IOBuffer(kSize)); + scoped_refptr<net::IOBuffer> buffer_read(new net::IOBuffer(kSize)); + CacheTestFillBuffer(buffer->data(), kSize, false); + + ASSERT_EQ(net::OK, CreateEntry(key, &entry)); + EXPECT_TRUE(entry); + + // Write something into stream0. + EXPECT_EQ(kSize, WriteData(entry, 0, 0, buffer.get(), kSize, false)); + EXPECT_EQ(kSize, ReadData(entry, 0, 0, buffer_read.get(), kSize)); + EXPECT_EQ(0, memcmp(buffer->data(), buffer_read->data(), kSize)); + entry->Close(); + + // Extend stream1. + ASSERT_EQ(net::OK, OpenEntry(key, &entry)); + int stream1_size = 100; + EXPECT_EQ(0, WriteData(entry, 1, stream1_size, buffer.get(), 0, false)); + EXPECT_EQ(stream1_size, entry->GetDataSize(1)); + entry->Close(); + + // Check that stream0 data has not been modified and that the EOF record for + // stream 0 contains a crc. + // The entry needs to be reopened before checking the crc: Open will perform + // the synchronization with the previous Close. This ensures the EOF records + // have been written to disk before we attempt to read them independently. + ASSERT_EQ(net::OK, OpenEntry(key, &entry)); + base::FilePath entry_file0_path = cache_path_.AppendASCII( + disk_cache::simple_util::GetFilenameFromKeyAndFileIndex(key, 0)); + int flags = base::PLATFORM_FILE_READ | base::PLATFORM_FILE_OPEN; + base::PlatformFile entry_file0 = + base::CreatePlatformFile(entry_file0_path, flags, NULL, NULL); + ASSERT_TRUE(entry_file0 != base::kInvalidPlatformFileValue); + + int data_size[disk_cache::kSimpleEntryStreamCount] = {kSize, stream1_size, 0}; + disk_cache::SimpleEntryStat entry_stat( + base::Time::Now(), base::Time::Now(), data_size); + int eof_offset = entry_stat.GetEOFOffsetInFile(key, 0); + disk_cache::SimpleFileEOF eof_record; + ASSERT_EQ(static_cast<int>(sizeof(eof_record)), base::ReadPlatformFile( + entry_file0, + eof_offset, + reinterpret_cast<char*>(&eof_record), + sizeof(eof_record))); + EXPECT_EQ(disk_cache::kSimpleFinalMagicNumber, eof_record.final_magic_number); + EXPECT_TRUE((eof_record.flags & disk_cache::SimpleFileEOF::FLAG_HAS_CRC32) == + disk_cache::SimpleFileEOF::FLAG_HAS_CRC32); + + buffer_read = new net::IOBuffer(kSize); + EXPECT_EQ(kSize, ReadData(entry, 0, 0, buffer_read.get(), kSize)); + EXPECT_EQ(0, memcmp(buffer->data(), buffer_read->data(), kSize)); + + // Shrink stream1. + stream1_size = 50; + EXPECT_EQ(0, WriteData(entry, 1, stream1_size, buffer.get(), 0, true)); + EXPECT_EQ(stream1_size, entry->GetDataSize(1)); + entry->Close(); + + // Check that stream0 data has not been modified. + buffer_read = new net::IOBuffer(kSize); + ASSERT_EQ(net::OK, OpenEntry(key, &entry)); + EXPECT_EQ(kSize, ReadData(entry, 0, 0, buffer_read.get(), kSize)); + EXPECT_EQ(0, memcmp(buffer->data(), buffer_read->data(), kSize)); + entry->Close(); + entry = NULL; +} + +// Test that writing within the range for which the crc has already been +// computed will properly invalidate the computed crc. +TEST_F(DiskCacheEntryTest, SimpleCacheCRCRewrite) { + // Test sequence: + // Create, Write (big data), Write (small data in the middle), Close. + // Open, Read (all), Close. + SetSimpleCacheMode(); + InitCache(); + disk_cache::Entry* null = NULL; + const char key[] = "the first key"; + + const int kHalfSize = 200; + const int kSize = 2 * kHalfSize; + scoped_refptr<net::IOBuffer> buffer1(new net::IOBuffer(kSize)); + scoped_refptr<net::IOBuffer> buffer2(new net::IOBuffer(kHalfSize)); + CacheTestFillBuffer(buffer1->data(), kSize, false); + CacheTestFillBuffer(buffer2->data(), kHalfSize, false); + + disk_cache::Entry* entry = NULL; + ASSERT_EQ(net::OK, CreateEntry(key, &entry)); + EXPECT_NE(null, entry); + entry->Close(); + + for (int i = 0; i < disk_cache::kSimpleEntryStreamCount; ++i) { + ASSERT_EQ(net::OK, OpenEntry(key, &entry)); + int offset = 0; + int buf_len = kSize; + + EXPECT_EQ(buf_len, + WriteData(entry, i, offset, buffer1.get(), buf_len, false)); + offset = kHalfSize; + buf_len = kHalfSize; + EXPECT_EQ(buf_len, + WriteData(entry, i, offset, buffer2.get(), buf_len, false)); + entry->Close(); + + ASSERT_EQ(net::OK, OpenEntry(key, &entry)); + + scoped_refptr<net::IOBuffer> buffer1_read1(new net::IOBuffer(kSize)); + EXPECT_EQ(kSize, ReadData(entry, i, 0, buffer1_read1.get(), kSize)); + EXPECT_EQ(0, memcmp(buffer1->data(), buffer1_read1->data(), kHalfSize)); + EXPECT_EQ( + 0, + memcmp(buffer2->data(), buffer1_read1->data() + kHalfSize, kHalfSize)); + + entry->Close(); + } +} + +bool DiskCacheEntryTest::SimpleCacheThirdStreamFileExists(const char* key) { + int third_stream_file_index = + disk_cache::simple_util::GetFileIndexFromStreamIndex(2); + base::FilePath third_stream_file_path = cache_path_.AppendASCII( + disk_cache::simple_util::GetFilenameFromKeyAndFileIndex( + key, third_stream_file_index)); + return PathExists(third_stream_file_path); +} + +void DiskCacheEntryTest::SyncDoomEntry(const char* key) { + net::TestCompletionCallback callback; + cache_->DoomEntry(key, callback.callback()); + callback.WaitForResult(); +} + +// Check that a newly-created entry with no third-stream writes omits the +// third stream file. +TEST_F(DiskCacheEntryTest, SimpleCacheOmittedThirdStream1) { + SetSimpleCacheMode(); + InitCache(); + + const int kHalfSize = 8; + const int kSize = kHalfSize * 2; + const char key[] = "key"; + scoped_refptr<net::IOBuffer> buffer1(new net::IOBuffer(kSize)); + scoped_refptr<net::IOBuffer> buffer2(new net::IOBuffer(kSize)); + CacheTestFillBuffer(buffer1->data(), kHalfSize, false); + + disk_cache::Entry* entry; + + // Create entry and close without writing: third stream file should be + // omitted, since the stream is empty. + ASSERT_EQ(net::OK, CreateEntry(key, &entry)); + entry->Close(); + EXPECT_FALSE(SimpleCacheThirdStreamFileExists(key)); + + SyncDoomEntry(key); + EXPECT_FALSE(SimpleCacheThirdStreamFileExists(key)); +} + +// Check that a newly-created entry with only a single zero-offset, zero-length +// write omits the third stream file. +TEST_F(DiskCacheEntryTest, SimpleCacheOmittedThirdStream2) { + SetSimpleCacheMode(); + InitCache(); + + const int kHalfSize = 8; + const int kSize = kHalfSize * 2; + const char key[] = "key"; + scoped_refptr<net::IOBuffer> buffer1(new net::IOBuffer(kSize)); + scoped_refptr<net::IOBuffer> buffer2(new net::IOBuffer(kSize)); + CacheTestFillBuffer(buffer1->data(), kHalfSize, false); + + disk_cache::Entry* entry; + int buf_len; + + // Create entry, write empty buffer to third stream, and close: third stream + // should still be omitted, since the entry ignores writes that don't modify + // data or change the length. + ASSERT_EQ(net::OK, CreateEntry(key, &entry)); + buf_len = WriteData(entry, 2, 0, buffer1, 0, true); + ASSERT_EQ(0, buf_len); + entry->Close(); + EXPECT_FALSE(SimpleCacheThirdStreamFileExists(key)); + + SyncDoomEntry(key); + EXPECT_FALSE(SimpleCacheThirdStreamFileExists(key)); +} + +// Check that we can read back data written to the third stream. +TEST_F(DiskCacheEntryTest, SimpleCacheOmittedThirdStream3) { + SetSimpleCacheMode(); + InitCache(); + + const int kHalfSize = 8; + const int kSize = kHalfSize * 2; + const char key[] = "key"; + scoped_refptr<net::IOBuffer> buffer1(new net::IOBuffer(kSize)); + scoped_refptr<net::IOBuffer> buffer2(new net::IOBuffer(kSize)); + CacheTestFillBuffer(buffer1->data(), kHalfSize, false); + + disk_cache::Entry* entry; + int buf_len; + + // Create entry, write data to third stream, and close: third stream should + // not be omitted, since it contains data. Re-open entry and ensure there + // are that many bytes in the third stream. + ASSERT_EQ(net::OK, CreateEntry(key, &entry)); + buf_len = WriteData(entry, 2, 0, buffer1, kHalfSize, true); + ASSERT_EQ(kHalfSize, buf_len); + entry->Close(); + EXPECT_TRUE(SimpleCacheThirdStreamFileExists(key)); + + ASSERT_EQ(net::OK, OpenEntry(key, &entry)); + buf_len = ReadData(entry, 2, 0, buffer2, kSize); + ASSERT_EQ(buf_len, kHalfSize); + // TODO: Compare data? + entry->Close(); + EXPECT_TRUE(SimpleCacheThirdStreamFileExists(key)); + + SyncDoomEntry(key); + EXPECT_FALSE(SimpleCacheThirdStreamFileExists(key)); +} + +// Check that we remove the third stream file upon opening an entry and finding +// the third stream empty. (This is the upgrade path for entries written +// before the third stream was optional.) +TEST_F(DiskCacheEntryTest, SimpleCacheOmittedThirdStream4) { + SetSimpleCacheMode(); + InitCache(); + + const int kHalfSize = 8; + const int kSize = kHalfSize * 2; + const char key[] = "key"; + scoped_refptr<net::IOBuffer> buffer1(new net::IOBuffer(kSize)); + scoped_refptr<net::IOBuffer> buffer2(new net::IOBuffer(kSize)); + CacheTestFillBuffer(buffer1->data(), kHalfSize, false); + + disk_cache::Entry* entry; + int buf_len; + + // Create entry, write data to third stream, truncate third stream back to + // empty, and close: third stream will not initially be omitted, since entry + // creates the file when the first significant write comes in, and only + // removes it on open if it is empty. Reopen, ensure that the file is + // deleted, and that there's no data in the third stream. + ASSERT_EQ(net::OK, CreateEntry(key, &entry)); + buf_len = WriteData(entry, 2, 0, buffer1, kHalfSize, true); + ASSERT_EQ(kHalfSize, buf_len); + buf_len = WriteData(entry, 2, 0, buffer1, 0, true); + ASSERT_EQ(0, buf_len); + entry->Close(); + EXPECT_TRUE(SimpleCacheThirdStreamFileExists(key)); + + ASSERT_EQ(net::OK, OpenEntry(key, &entry)); + EXPECT_FALSE(SimpleCacheThirdStreamFileExists(key)); + buf_len = ReadData(entry, 2, 0, buffer2, kSize); + ASSERT_EQ(0, buf_len); + entry->Close(); + EXPECT_FALSE(SimpleCacheThirdStreamFileExists(key)); + + SyncDoomEntry(key); + EXPECT_FALSE(SimpleCacheThirdStreamFileExists(key)); +} + +// Check that we don't accidentally create the third stream file once the entry +// has been doomed. +TEST_F(DiskCacheEntryTest, SimpleCacheOmittedThirdStream5) { + SetSimpleCacheMode(); + InitCache(); + + const int kHalfSize = 8; + const int kSize = kHalfSize * 2; + const char key[] = "key"; + scoped_refptr<net::IOBuffer> buffer1(new net::IOBuffer(kSize)); + scoped_refptr<net::IOBuffer> buffer2(new net::IOBuffer(kSize)); + CacheTestFillBuffer(buffer1->data(), kHalfSize, false); + + disk_cache::Entry* entry; + + // Create entry, doom entry, write data to third stream, and close: third + // stream should not exist. (Note: We don't care if the write fails, just + // that it doesn't cause the file to be created on disk.) + ASSERT_EQ(net::OK, CreateEntry(key, &entry)); + entry->Doom(); + WriteData(entry, 2, 0, buffer1, kHalfSize, true); + entry->Close(); + EXPECT_FALSE(SimpleCacheThirdStreamFileExists(key)); +} + #endif // defined(OS_POSIX) diff --git a/chromium/net/disk_cache/file.h b/chromium/net/disk_cache/file.h index 3038d884142..eb9a9ecc1e1 100644 --- a/chromium/net/disk_cache/file.h +++ b/chromium/net/disk_cache/file.h @@ -70,18 +70,20 @@ class NET_EXPORT_PRIVATE File : public base::RefCounted<File> { // Blocks until |num_pending_io| IO operations complete. static void WaitForPendingIO(int* num_pending_io); - // Drops current pending operations without waiting for them to complete. - static void DropPendingIO(); - protected: virtual ~File(); + private: // Performs the actual asynchronous write. If notify is set and there is no // callback, the call will be re-synchronized. bool AsyncWrite(const void* buffer, size_t buffer_len, size_t offset, FileIOCallback* callback, bool* completed); - private: + // Infrastructure for async IO. + int DoRead(void* buffer, size_t buffer_len, size_t offset); + int DoWrite(const void* buffer, size_t buffer_len, size_t offset); + void OnOperationComplete(FileIOCallback* callback, int result); + bool init_; bool mixed_; base::PlatformFile platform_file_; // Regular, asynchronous IO handle. diff --git a/chromium/net/disk_cache/file_posix.cc b/chromium/net/disk_cache/file_posix.cc index 2ad3db916e3..30c4a660b13 100644 --- a/chromium/net/disk_cache/file_posix.cc +++ b/chromium/net/disk_cache/file_posix.cc @@ -4,165 +4,31 @@ #include "net/disk_cache/file.h" -#include <fcntl.h> - #include "base/bind.h" +#include "base/lazy_instance.h" #include "base/location.h" #include "base/logging.h" -#include "base/threading/worker_pool.h" +#include "base/run_loop.h" +#include "base/task_runner_util.h" +#include "base/threading/sequenced_worker_pool.h" #include "net/base/net_errors.h" #include "net/disk_cache/disk_cache.h" -#include "net/disk_cache/in_flight_io.h" namespace { -// This class represents a single asynchronous IO operation while it is being -// bounced between threads. -class FileBackgroundIO : public disk_cache::BackgroundIO { - public: - // Other than the actual parameters for the IO operation (including the - // |callback| that must be notified at the end), we need the controller that - // is keeping track of all operations. When done, we notify the controller - // (we do NOT invoke the callback), in the worker thead that completed the - // operation. - FileBackgroundIO(disk_cache::File* file, const void* buf, size_t buf_len, - size_t offset, disk_cache::FileIOCallback* callback, - disk_cache::InFlightIO* controller) - : disk_cache::BackgroundIO(controller), callback_(callback), file_(file), - buf_(buf), buf_len_(buf_len), offset_(offset) { - } - - disk_cache::FileIOCallback* callback() { - return callback_; - } - - disk_cache::File* file() { - return file_; - } - - // Read and Write are the operations that can be performed asynchronously. - // The actual parameters for the operation are setup in the constructor of - // the object. Both methods should be called from a worker thread, by posting - // a task to the WorkerPool (they are RunnableMethods). When finished, - // controller->OnIOComplete() is called. - void Read(); - void Write(); - - private: - virtual ~FileBackgroundIO() {} +// The maximum number of threads for this pool. +const int kMaxThreads = 5; - disk_cache::FileIOCallback* callback_; - - disk_cache::File* file_; - const void* buf_; - size_t buf_len_; - size_t offset_; - - DISALLOW_COPY_AND_ASSIGN(FileBackgroundIO); -}; - - -// The specialized controller that keeps track of current operations. -class FileInFlightIO : public disk_cache::InFlightIO { +class FileWorkerPool : public base::SequencedWorkerPool { public: - FileInFlightIO() {} - virtual ~FileInFlightIO() {} - - // These methods start an asynchronous operation. The arguments have the same - // semantics of the File asynchronous operations, with the exception that the - // operation never finishes synchronously. - void PostRead(disk_cache::File* file, void* buf, size_t buf_len, - size_t offset, disk_cache::FileIOCallback* callback); - void PostWrite(disk_cache::File* file, const void* buf, size_t buf_len, - size_t offset, disk_cache::FileIOCallback* callback); + FileWorkerPool() : base::SequencedWorkerPool(kMaxThreads, "CachePool") {} protected: - // Invokes the users' completion callback at the end of the IO operation. - // |cancel| is true if the actual task posted to the thread is still - // queued (because we are inside WaitForPendingIO), and false if said task is - // the one performing the call. - virtual void OnOperationComplete(disk_cache::BackgroundIO* operation, - bool cancel) OVERRIDE; - - private: - DISALLOW_COPY_AND_ASSIGN(FileInFlightIO); + virtual ~FileWorkerPool() {} }; -// --------------------------------------------------------------------------- - -// Runs on a worker thread. -void FileBackgroundIO::Read() { - if (file_->Read(const_cast<void*>(buf_), buf_len_, offset_)) { - result_ = static_cast<int>(buf_len_); - } else { - result_ = net::ERR_CACHE_READ_FAILURE; - } - NotifyController(); -} - -// Runs on a worker thread. -void FileBackgroundIO::Write() { - bool rv = file_->Write(buf_, buf_len_, offset_); - - result_ = rv ? static_cast<int>(buf_len_) : net::ERR_CACHE_WRITE_FAILURE; - NotifyController(); -} - -// --------------------------------------------------------------------------- - -void FileInFlightIO::PostRead(disk_cache::File *file, void* buf, size_t buf_len, - size_t offset, disk_cache::FileIOCallback *callback) { - scoped_refptr<FileBackgroundIO> operation( - new FileBackgroundIO(file, buf, buf_len, offset, callback, this)); - file->AddRef(); // Balanced on OnOperationComplete() - - base::WorkerPool::PostTask(FROM_HERE, - base::Bind(&FileBackgroundIO::Read, operation.get()), true); - OnOperationPosted(operation.get()); -} - -void FileInFlightIO::PostWrite(disk_cache::File* file, const void* buf, - size_t buf_len, size_t offset, - disk_cache::FileIOCallback* callback) { - scoped_refptr<FileBackgroundIO> operation( - new FileBackgroundIO(file, buf, buf_len, offset, callback, this)); - file->AddRef(); // Balanced on OnOperationComplete() - - base::WorkerPool::PostTask(FROM_HERE, - base::Bind(&FileBackgroundIO::Write, operation.get()), true); - OnOperationPosted(operation.get()); -} - -// Runs on the IO thread. -void FileInFlightIO::OnOperationComplete(disk_cache::BackgroundIO* operation, - bool cancel) { - FileBackgroundIO* op = static_cast<FileBackgroundIO*>(operation); - - disk_cache::FileIOCallback* callback = op->callback(); - int bytes = operation->result(); - - // Release the references acquired in PostRead / PostWrite. - op->file()->Release(); - callback->OnFileIOComplete(bytes); -} - -// A static object tha will broker all async operations. -FileInFlightIO* s_file_operations = NULL; - -// Returns the current FileInFlightIO. -FileInFlightIO* GetFileInFlightIO() { - if (!s_file_operations) { - s_file_operations = new FileInFlightIO; - } - return s_file_operations; -} - -// Deletes the current FileInFlightIO. -void DeleteFileInFlightIO() { - DCHECK(s_file_operations); - delete s_file_operations; - s_file_operations = NULL; -} +base::LazyInstance<FileWorkerPool>::Leaky s_worker_pool = + LAZY_INSTANCE_INITIALIZER; } // namespace @@ -205,8 +71,9 @@ bool File::IsValid() const { bool File::Read(void* buffer, size_t buffer_len, size_t offset) { DCHECK(init_); if (buffer_len > static_cast<size_t>(kint32max) || - offset > static_cast<size_t>(kint32max)) + offset > static_cast<size_t>(kint32max)) { return false; + } int ret = base::ReadPlatformFile(platform_file_, offset, static_cast<char*>(buffer), buffer_len); @@ -216,8 +83,9 @@ bool File::Read(void* buffer, size_t buffer_len, size_t offset) { bool File::Write(const void* buffer, size_t buffer_len, size_t offset) { DCHECK(init_); if (buffer_len > static_cast<size_t>(kint32max) || - offset > static_cast<size_t>(kint32max)) + offset > static_cast<size_t>(kint32max)) { return false; + } int ret = base::WritePlatformFile(platform_file_, offset, static_cast<const char*>(buffer), @@ -225,9 +93,6 @@ bool File::Write(const void* buffer, size_t buffer_len, size_t offset) { return (static_cast<size_t>(ret) == buffer_len); } -// We have to increase the ref counter of the file before performing the IO to -// prevent the completion to happen with an invalid handle (if the file is -// closed while the IO is in flight). bool File::Read(void* buffer, size_t buffer_len, size_t offset, FileIOCallback* callback, bool* completed) { DCHECK(init_); @@ -237,10 +102,15 @@ bool File::Read(void* buffer, size_t buffer_len, size_t offset, return Read(buffer, buffer_len, offset); } - if (buffer_len > ULONG_MAX || offset > ULONG_MAX) + if (buffer_len > static_cast<size_t>(kint32max) || + offset > static_cast<size_t>(kint32max)) { return false; + } - GetFileInFlightIO()->PostRead(this, buffer, buffer_len, offset, callback); + base::PostTaskAndReplyWithResult( + s_worker_pool.Pointer(), FROM_HERE, + base::Bind(&File::DoRead, this, buffer, buffer_len, offset), + base::Bind(&File::OnOperationComplete, this, callback)); *completed = false; return true; @@ -255,12 +125,23 @@ bool File::Write(const void* buffer, size_t buffer_len, size_t offset, return Write(buffer, buffer_len, offset); } - return AsyncWrite(buffer, buffer_len, offset, callback, completed); + if (buffer_len > static_cast<size_t>(kint32max) || + offset > static_cast<size_t>(kint32max)) { + return false; + } + + base::PostTaskAndReplyWithResult( + s_worker_pool.Pointer(), FROM_HERE, + base::Bind(&File::DoWrite, this, buffer, buffer_len, offset), + base::Bind(&File::OnOperationComplete, this, callback)); + + *completed = false; + return true; } bool File::SetLength(size_t length) { DCHECK(init_); - if (length > ULONG_MAX) + if (length > kuint32max) return false; return base::TruncatePlatformFile(platform_file_, length); @@ -268,24 +149,22 @@ bool File::SetLength(size_t length) { size_t File::GetLength() { DCHECK(init_); - off_t ret = lseek(platform_file_, 0, SEEK_END); - if (ret < 0) - return 0; - return ret; -} + int64 len = base::SeekPlatformFile(platform_file_, + base::PLATFORM_FILE_FROM_END, 0); -// Static. -void File::WaitForPendingIO(int* num_pending_io) { - // We may be running unit tests so we should allow be able to reset the - // message loop. - GetFileInFlightIO()->WaitForPendingIO(); - DeleteFileInFlightIO(); + if (len > static_cast<int64>(kuint32max)) + return kuint32max; + + return static_cast<size_t>(len); } // Static. -void File::DropPendingIO() { - GetFileInFlightIO()->DropPendingIO(); - DeleteFileInFlightIO(); +void File::WaitForPendingIO(int* num_pending_io) { + // We are running unit tests so we should wait for all callbacks. Sadly, the + // worker pool only waits for tasks on the worker pool, not the "Reply" tasks + // so we have to let the current message loop to run. + s_worker_pool.Get().FlushForTesting(); + base::RunLoop().RunUntilIdle(); } File::~File() { @@ -293,17 +172,26 @@ File::~File() { base::ClosePlatformFile(platform_file_); } -bool File::AsyncWrite(const void* buffer, size_t buffer_len, size_t offset, - FileIOCallback* callback, bool* completed) { - DCHECK(init_); - if (buffer_len > ULONG_MAX || offset > ULONG_MAX) - return false; +// Runs on a worker thread. +int File::DoRead(void* buffer, size_t buffer_len, size_t offset) { + if (Read(const_cast<void*>(buffer), buffer_len, offset)) + return static_cast<int>(buffer_len); - GetFileInFlightIO()->PostWrite(this, buffer, buffer_len, offset, callback); + return net::ERR_CACHE_READ_FAILURE; +} - if (completed) - *completed = false; - return true; +// Runs on a worker thread. +int File::DoWrite(const void* buffer, size_t buffer_len, size_t offset) { + if (Write(const_cast<void*>(buffer), buffer_len, offset)) + return static_cast<int>(buffer_len); + + return net::ERR_CACHE_WRITE_FAILURE; +} + +// This method actually makes sure that the last reference to the file doesn't +// go away on the worker pool. +void File::OnOperationComplete(FileIOCallback* callback, int result) { + callback->OnFileIOComplete(result); } } // namespace disk_cache diff --git a/chromium/net/disk_cache/file_win.cc b/chromium/net/disk_cache/file_win.cc index f284b501045..dbb34f38341 100644 --- a/chromium/net/disk_cache/file_win.cc +++ b/chromium/net/disk_cache/file_win.cc @@ -267,9 +267,4 @@ void File::WaitForPendingIO(int* num_pending_io) { } } -// Static. -void File::DropPendingIO() { - // Nothing to do here. -} - } // namespace disk_cache diff --git a/chromium/net/disk_cache/histogram_macros.h b/chromium/net/disk_cache/histogram_macros.h index 3d8011c27f6..651bce96f2d 100644 --- a/chromium/net/disk_cache/histogram_macros.h +++ b/chromium/net/disk_cache/histogram_macros.h @@ -115,6 +115,9 @@ case net::SHADER_CACHE:\ CACHE_HISTOGRAM_##type(my_name.data(), sample);\ break;\ + case net::PNACL_CACHE:\ + CACHE_HISTOGRAM_##type(my_name.data(), sample);\ + break;\ default:\ NOTREACHED();\ break;\ diff --git a/chromium/net/disk_cache/mapped_file.cc b/chromium/net/disk_cache/mapped_file.cc index f17a1004a90..dd745ac5268 100644 --- a/chromium/net/disk_cache/mapped_file.cc +++ b/chromium/net/disk_cache/mapped_file.cc @@ -1,53 +1,12 @@ -// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Copyright 2013 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include "net/disk_cache/mapped_file.h" -#include "base/files/file_path.h" -#include "base/logging.h" -#include "base/memory/scoped_ptr.h" -#include "net/disk_cache/disk_cache.h" - namespace disk_cache { -void* MappedFile::Init(const base::FilePath& name, size_t size) { - DCHECK(!init_); - if (init_ || !File::Init(name)) - return NULL; - - buffer_ = NULL; - init_ = true; - section_ = CreateFileMapping(platform_file(), NULL, PAGE_READWRITE, 0, - static_cast<DWORD>(size), NULL); - if (!section_) - return NULL; - - buffer_ = MapViewOfFile(section_, FILE_MAP_READ | FILE_MAP_WRITE, 0, 0, size); - DCHECK(buffer_); - view_size_ = size; - - // Make sure we detect hardware failures reading the headers. - size_t temp_len = size ? size : 4096; - scoped_ptr<char[]> temp(new char[temp_len]); - if (!Read(temp.get(), temp_len, 0)) - return NULL; - - return buffer_; -} - -MappedFile::~MappedFile() { - if (!init_) - return; - - if (buffer_) { - BOOL ret = UnmapViewOfFile(buffer_); - DCHECK(ret); - } - - if (section_) - CloseHandle(section_); -} +// Note: Most of this class is implemented in platform-specific files. bool MappedFile::Load(const FileBlock* block) { size_t offset = block->offset() + view_size_; @@ -59,7 +18,18 @@ bool MappedFile::Store(const FileBlock* block) { return Write(block->buffer(), block->size(), offset); } -void MappedFile::Flush() { +bool MappedFile::Load(const FileBlock* block, + FileIOCallback* callback, + bool* completed) { + size_t offset = block->offset() + view_size_; + return Read(block->buffer(), block->size(), offset, callback, completed); +} + +bool MappedFile::Store(const FileBlock* block, + FileIOCallback* callback, + bool* completed) { + size_t offset = block->offset() + view_size_; + return Write(block->buffer(), block->size(), offset, callback, completed); } } // namespace disk_cache diff --git a/chromium/net/disk_cache/mapped_file.h b/chromium/net/disk_cache/mapped_file.h index 4649b90d1c9..ba7f4adae65 100644 --- a/chromium/net/disk_cache/mapped_file.h +++ b/chromium/net/disk_cache/mapped_file.h @@ -38,6 +38,11 @@ class NET_EXPORT_PRIVATE MappedFile : public File { bool Load(const FileBlock* block); bool Store(const FileBlock* block); + // Asynchronous versions of Load/Store, following the semantics of File::Read + // and File::Write. + bool Load(const FileBlock* block, FileIOCallback* callback, bool* completed); + bool Store(const FileBlock* block, FileIOCallback* callback, bool* completed); + // Flush the memory-mapped section to disk (synchronously). void Flush(); diff --git a/chromium/net/disk_cache/mapped_file_avoid_mmap_posix.cc b/chromium/net/disk_cache/mapped_file_avoid_mmap_posix.cc index cd514a366e2..940466ece6a 100644 --- a/chromium/net/disk_cache/mapped_file_avoid_mmap_posix.cc +++ b/chromium/net/disk_cache/mapped_file_avoid_mmap_posix.cc @@ -34,16 +34,6 @@ void* MappedFile::Init(const base::FilePath& name, size_t size) { return buffer_; } -bool MappedFile::Load(const FileBlock* block) { - size_t offset = block->offset() + view_size_; - return Read(block->buffer(), block->size(), offset); -} - -bool MappedFile::Store(const FileBlock* block) { - size_t offset = block->offset() + view_size_; - return Write(block->buffer(), block->size(), offset); -} - void MappedFile::Flush() { DCHECK(buffer_); DCHECK(snapshot_); diff --git a/chromium/net/disk_cache/mapped_file_posix.cc b/chromium/net/disk_cache/mapped_file_posix.cc index 2146245d4aa..576d02afaea 100644 --- a/chromium/net/disk_cache/mapped_file_posix.cc +++ b/chromium/net/disk_cache/mapped_file_posix.cc @@ -38,16 +38,6 @@ void* MappedFile::Init(const base::FilePath& name, size_t size) { return buffer_; } -bool MappedFile::Load(const FileBlock* block) { - size_t offset = block->offset() + view_size_; - return Read(block->buffer(), block->size(), offset); -} - -bool MappedFile::Store(const FileBlock* block) { - size_t offset = block->offset() + view_size_; - return Write(block->buffer(), block->size(), offset); -} - void MappedFile::Flush() { } diff --git a/chromium/net/disk_cache/mapped_file_unittest.cc b/chromium/net/disk_cache/mapped_file_unittest.cc index 8798db0170c..7afb8cc4e1c 100644 --- a/chromium/net/disk_cache/mapped_file_unittest.cc +++ b/chromium/net/disk_cache/mapped_file_unittest.cc @@ -7,6 +7,7 @@ #include "base/strings/string_util.h" #include "net/disk_cache/disk_cache_test_base.h" #include "net/disk_cache/disk_cache_test_util.h" +#include "net/disk_cache/file_block.h" #include "net/disk_cache/mapped_file.h" #include "testing/gtest/include/gtest/gtest.h" @@ -39,6 +40,22 @@ void FileCallbackTest::OnFileIOComplete(int bytes_copied) { helper_->CallbackWasCalled(); } +class TestFileBlock : public disk_cache::FileBlock { + public: + TestFileBlock() { + CacheTestFillBuffer(buffer_, sizeof(buffer_), false); + } + virtual ~TestFileBlock() {} + + // FileBlock interface. + virtual void* buffer() const OVERRIDE { return const_cast<char*>(buffer_); } + virtual size_t size() const OVERRIDE { return sizeof(buffer_); } + virtual int offset() const OVERRIDE { return 1024; } + + private: + char buffer_[20]; +}; + } // namespace TEST_F(DiskCacheTest, MappedFile_SyncIO) { @@ -89,3 +106,36 @@ TEST_F(DiskCacheTest, MappedFile_AsyncIO) { EXPECT_FALSE(helper.callback_reused_error()); EXPECT_STREQ(buffer1, buffer2); } + +TEST_F(DiskCacheTest, MappedFile_AsyncLoadStore) { + base::FilePath filename = cache_path_.AppendASCII("a_test"); + scoped_refptr<disk_cache::MappedFile> file(new disk_cache::MappedFile); + ASSERT_TRUE(CreateCacheTestFile(filename)); + ASSERT_TRUE(file->Init(filename, 8192)); + + int max_id = 0; + MessageLoopHelper helper; + FileCallbackTest callback(1, &helper, &max_id); + + TestFileBlock file_block1; + TestFileBlock file_block2; + base::strlcpy(static_cast<char*>(file_block1.buffer()), "the data", + file_block1.size()); + bool completed; + EXPECT_TRUE(file->Store(&file_block1, &callback, &completed)); + int expected = completed ? 0 : 1; + + max_id = 1; + helper.WaitUntilCacheIoFinished(expected); + + EXPECT_TRUE(file->Load(&file_block2, &callback, &completed)); + if (!completed) + expected++; + + helper.WaitUntilCacheIoFinished(expected); + + EXPECT_EQ(expected, helper.callbacks_called()); + EXPECT_FALSE(helper.callback_reused_error()); + EXPECT_STREQ(static_cast<char*>(file_block1.buffer()), + static_cast<char*>(file_block2.buffer())); +} diff --git a/chromium/net/disk_cache/mapped_file_win.cc b/chromium/net/disk_cache/mapped_file_win.cc index f17a1004a90..b795bf47835 100644 --- a/chromium/net/disk_cache/mapped_file_win.cc +++ b/chromium/net/disk_cache/mapped_file_win.cc @@ -49,16 +49,6 @@ MappedFile::~MappedFile() { CloseHandle(section_); } -bool MappedFile::Load(const FileBlock* block) { - size_t offset = block->offset() + view_size_; - return Read(block->buffer(), block->size(), offset); -} - -bool MappedFile::Store(const FileBlock* block) { - size_t offset = block->offset() + view_size_; - return Write(block->buffer(), block->size(), offset); -} - void MappedFile::Flush() { } diff --git a/chromium/net/disk_cache/mem_entry_impl.h b/chromium/net/disk_cache/mem_entry_impl.h index ef91f6d7b0c..b84cc39ef22 100644 --- a/chromium/net/disk_cache/mem_entry_impl.h +++ b/chromium/net/disk_cache/mem_entry_impl.h @@ -82,11 +82,7 @@ class MemEntryImpl : public Entry { return parent_ ? kChildEntry : kParentEntry; } - std::string& key() { - return key_; - } - - net::BoundNetLog& net_log() { + const net::BoundNetLog& net_log() { return net_log_; } diff --git a/chromium/net/disk_cache/simple/simple_backend_impl.cc b/chromium/net/disk_cache/simple/simple_backend_impl.cc index 2877c01f701..8856a2d7194 100644 --- a/chromium/net/disk_cache/simple/simple_backend_impl.cc +++ b/chromium/net/disk_cache/simple/simple_backend_impl.cc @@ -6,6 +6,7 @@ #include <algorithm> #include <cstdlib> +#include <functional> #if defined(OS_POSIX) #include <sys/resource.h> @@ -28,11 +29,14 @@ #include "net/disk_cache/backend_impl.h" #include "net/disk_cache/simple/simple_entry_format.h" #include "net/disk_cache/simple/simple_entry_impl.h" +#include "net/disk_cache/simple/simple_histogram_macros.h" #include "net/disk_cache/simple/simple_index.h" #include "net/disk_cache/simple/simple_index_file.h" #include "net/disk_cache/simple/simple_synchronous_entry.h" #include "net/disk_cache/simple/simple_util.h" +#include "net/disk_cache/simple/simple_version_upgrade.h" +using base::Callback; using base::Closure; using base::FilePath; using base::MessageLoopProxy; @@ -42,6 +46,8 @@ using base::Time; using base::DirectoryExists; using file_util::CreateDirectory; +namespace disk_cache { + namespace { // Maximum number of concurrent worker pool threads, which also is the limit @@ -78,7 +84,7 @@ void MaybeCreateSequencedWorkerPool() { bool g_fd_limit_histogram_has_been_populated = false; -void MaybeHistogramFdLimit() { +void MaybeHistogramFdLimit(net::CacheType cache_type) { if (g_fd_limit_histogram_has_been_populated) return; @@ -104,13 +110,14 @@ void MaybeHistogramFdLimit() { } #endif - UMA_HISTOGRAM_ENUMERATION("SimpleCache.FileDescriptorLimitStatus", - fd_limit_status, FD_LIMIT_STATUS_MAX); + SIMPLE_CACHE_UMA(ENUMERATION, + "FileDescriptorLimitStatus", cache_type, + fd_limit_status, FD_LIMIT_STATUS_MAX); if (fd_limit_status == FD_LIMIT_STATUS_SUCCEEDED) { - UMA_HISTOGRAM_SPARSE_SLOWLY("SimpleCache.FileDescriptorLimitSoft", - soft_fd_limit); - UMA_HISTOGRAM_SPARSE_SLOWLY("SimpleCache.FileDescriptorLimitHard", - hard_fd_limit); + SIMPLE_CACHE_UMA(SPARSE_SLOWLY, + "FileDescriptorLimitSoft", cache_type, soft_fd_limit); + SIMPLE_CACHE_UMA(SPARSE_SLOWLY, + "FileDescriptorLimitHard", cache_type, hard_fd_limit); } g_fd_limit_histogram_has_been_populated = true; @@ -129,100 +136,112 @@ void DeleteBackendImpl(disk_cache::Backend** backend, // Detects if the files in the cache directory match the current disk cache // backend type and version. If the directory contains no cache, occupies it // with the fresh structure. -// -// There is a convention among disk cache backends: looking at the magic in the -// file "index" it should be sufficient to determine if the cache belongs to the -// currently running backend. The Simple Backend stores its index in the file -// "the-real-index" (see simple_index.cc) and the file "index" only signifies -// presence of the implementation's magic and version. There are two reasons for -// that: -// 1. Absence of the index is itself not a fatal error in the Simple Backend -// 2. The Simple Backend has pickled file format for the index making it hacky -// to have the magic in the right place. bool FileStructureConsistent(const base::FilePath& path) { if (!base::PathExists(path) && !file_util::CreateDirectory(path)) { LOG(ERROR) << "Failed to create directory: " << path.LossyDisplayName(); return false; } - const base::FilePath fake_index = path.AppendASCII("index"); - base::PlatformFileError error; - base::PlatformFile fake_index_file = base::CreatePlatformFile( - fake_index, - base::PLATFORM_FILE_OPEN | base::PLATFORM_FILE_READ, - NULL, - &error); - if (error == base::PLATFORM_FILE_ERROR_NOT_FOUND) { - base::PlatformFile file = base::CreatePlatformFile( - fake_index, - base::PLATFORM_FILE_CREATE | base::PLATFORM_FILE_WRITE, - NULL, &error); - disk_cache::SimpleFileHeader file_contents; - file_contents.initial_magic_number = disk_cache::kSimpleInitialMagicNumber; - file_contents.version = disk_cache::kSimpleVersion; - int bytes_written = base::WritePlatformFile( - file, 0, reinterpret_cast<char*>(&file_contents), - sizeof(file_contents)); - if (!base::ClosePlatformFile(file) || - bytes_written != sizeof(file_contents)) { - LOG(ERROR) << "Failed to write cache structure file: " - << path.LossyDisplayName(); - return false; - } - return true; - } else if (error != base::PLATFORM_FILE_OK) { - LOG(ERROR) << "Could not open cache structure file: " - << path.LossyDisplayName(); - return false; + return disk_cache::UpgradeSimpleCacheOnDisk(path); +} + +// A context used by a BarrierCompletionCallback to track state. +struct BarrierContext { + BarrierContext(int expected) + : expected(expected), + count(0), + had_error(false) {} + + const int expected; + int count; + bool had_error; +}; + +void BarrierCompletionCallbackImpl( + BarrierContext* context, + const net::CompletionCallback& final_callback, + int result) { + DCHECK_GT(context->expected, context->count); + if (context->had_error) + return; + if (result != net::OK) { + context->had_error = true; + final_callback.Run(result); + return; + } + ++context->count; + if (context->count == context->expected) + final_callback.Run(net::OK); +} + +// A barrier completion callback is a net::CompletionCallback that waits for +// |count| successful results before invoking |final_callback|. In the case of +// an error, the first error is passed to |final_callback| and all others +// are ignored. +net::CompletionCallback MakeBarrierCompletionCallback( + int count, + const net::CompletionCallback& final_callback) { + BarrierContext* context = new BarrierContext(count); + return base::Bind(&BarrierCompletionCallbackImpl, + base::Owned(context), final_callback); +} + +// A short bindable thunk that ensures a completion callback is always called +// after running an operation asynchronously. +void RunOperationAndCallback( + const Callback<int(const net::CompletionCallback&)>& operation, + const net::CompletionCallback& operation_callback) { + const int operation_result = operation.Run(operation_callback); + if (operation_result != net::ERR_IO_PENDING) + operation_callback.Run(operation_result); +} + +// A short bindable thunk that Dooms an entry if it successfully opens. +void DoomOpenedEntry(scoped_ptr<Entry*> in_entry, + const net::CompletionCallback& doom_callback, + int open_result) { + DCHECK_NE(open_result, net::ERR_IO_PENDING); + if (open_result == net::OK) { + DCHECK(in_entry); + SimpleEntryImpl* simple_entry = static_cast<SimpleEntryImpl*>(*in_entry); + const int doom_result = simple_entry->DoomEntry(doom_callback); + simple_entry->Close(); + if (doom_result != net::ERR_IO_PENDING) + doom_callback.Run(doom_result); } else { - disk_cache::SimpleFileHeader file_header; - int bytes_read = base::ReadPlatformFile( - fake_index_file, 0, reinterpret_cast<char*>(&file_header), - sizeof(file_header)); - if (!base::ClosePlatformFile(fake_index_file) || - bytes_read != sizeof(file_header) || - file_header.initial_magic_number != - disk_cache::kSimpleInitialMagicNumber || - file_header.version != disk_cache::kSimpleVersion) { - LOG(ERROR) << "File structure does not match the disk cache backend."; - return false; - } - return true; + doom_callback.Run(open_result); } } -void CallCompletionCallback(const net::CompletionCallback& callback, - int error_code) { - DCHECK(!callback.is_null()); - callback.Run(error_code); -} - -void RecordIndexLoad(base::TimeTicks constructed_since, int result) { +void RecordIndexLoad(net::CacheType cache_type, + base::TimeTicks constructed_since, + int result) { const base::TimeDelta creation_to_index = base::TimeTicks::Now() - constructed_since; - if (result == net::OK) - UMA_HISTOGRAM_TIMES("SimpleCache.CreationToIndex", creation_to_index); - else - UMA_HISTOGRAM_TIMES("SimpleCache.CreationToIndexFail", creation_to_index); + if (result == net::OK) { + SIMPLE_CACHE_UMA(TIMES, "CreationToIndex", cache_type, creation_to_index); + } else { + SIMPLE_CACHE_UMA(TIMES, + "CreationToIndexFail", cache_type, creation_to_index); + } } } // namespace -namespace disk_cache { - SimpleBackendImpl::SimpleBackendImpl(const FilePath& path, int max_bytes, - net::CacheType type, + net::CacheType cache_type, base::SingleThreadTaskRunner* cache_thread, net::NetLog* net_log) : path_(path), + cache_type_(cache_type), cache_thread_(cache_thread), orig_max_size_(max_bytes), entry_operations_mode_( - type == net::DISK_CACHE ? + cache_type == net::DISK_CACHE ? SimpleEntryImpl::OPTIMISTIC_OPERATIONS : SimpleEntryImpl::NON_OPTIMISTIC_OPERATIONS), net_log_(net_log) { - MaybeHistogramFdLimit(); + MaybeHistogramFdLimit(cache_type_); } SimpleBackendImpl::~SimpleBackendImpl() { @@ -235,13 +254,12 @@ int SimpleBackendImpl::Init(const CompletionCallback& completion_callback) { worker_pool_ = g_sequenced_worker_pool->GetTaskRunnerWithShutdownBehavior( SequencedWorkerPool::CONTINUE_ON_SHUTDOWN); - index_.reset( - new SimpleIndex(MessageLoopProxy::current().get(), - path_, - make_scoped_ptr(new SimpleIndexFile( - cache_thread_.get(), worker_pool_.get(), path_)))); - index_->ExecuteWhenReady(base::Bind(&RecordIndexLoad, - base::TimeTicks::Now())); + index_.reset(new SimpleIndex(MessageLoopProxy::current(), this, cache_type_, + make_scoped_ptr(new SimpleIndexFile( + cache_thread_.get(), worker_pool_.get(), + cache_type_, path_)))); + index_->ExecuteWhenReady( + base::Bind(&RecordIndexLoad, cache_type_, base::TimeTicks::Now())); PostTaskAndReplyWithResult( cache_thread_, @@ -266,6 +284,85 @@ void SimpleBackendImpl::OnDeactivated(const SimpleEntryImpl* entry) { active_entries_.erase(entry->entry_hash()); } +void SimpleBackendImpl::OnDoomStart(uint64 entry_hash) { + DCHECK_EQ(0u, entries_pending_doom_.count(entry_hash)); + entries_pending_doom_.insert( + std::make_pair(entry_hash, std::vector<Closure>())); +} + +void SimpleBackendImpl::OnDoomComplete(uint64 entry_hash) { + DCHECK_EQ(1u, entries_pending_doom_.count(entry_hash)); + base::hash_map<uint64, std::vector<Closure> >::iterator it = + entries_pending_doom_.find(entry_hash); + std::vector<Closure> to_run_closures; + to_run_closures.swap(it->second); + entries_pending_doom_.erase(it); + + std::for_each(to_run_closures.begin(), to_run_closures.end(), + std::mem_fun_ref(&Closure::Run)); +} + +void SimpleBackendImpl::DoomEntries(std::vector<uint64>* entry_hashes, + const net::CompletionCallback& callback) { + scoped_ptr<std::vector<uint64> > + mass_doom_entry_hashes(new std::vector<uint64>()); + mass_doom_entry_hashes->swap(*entry_hashes); + + std::vector<uint64> to_doom_individually_hashes; + + // For each of the entry hashes, there are two cases: + // 1. The entry is either open or pending doom, and so it should be doomed + // individually to avoid flakes. + // 2. The entry is not in use at all, so we can call + // SimpleSynchronousEntry::DoomEntrySet and delete the files en masse. + for (int i = mass_doom_entry_hashes->size() - 1; i >= 0; --i) { + const uint64 entry_hash = (*mass_doom_entry_hashes)[i]; + DCHECK(active_entries_.count(entry_hash) == 0 || + entries_pending_doom_.count(entry_hash) == 0) + << "The entry 0x" << std::hex << entry_hash + << " is both active and pending doom."; + if (!active_entries_.count(entry_hash) && + !entries_pending_doom_.count(entry_hash)) { + continue; + } + + to_doom_individually_hashes.push_back(entry_hash); + + (*mass_doom_entry_hashes)[i] = mass_doom_entry_hashes->back(); + mass_doom_entry_hashes->resize(mass_doom_entry_hashes->size() - 1); + } + + net::CompletionCallback barrier_callback = + MakeBarrierCompletionCallback(to_doom_individually_hashes.size() + 1, + callback); + for (std::vector<uint64>::const_iterator + it = to_doom_individually_hashes.begin(), + end = to_doom_individually_hashes.end(); it != end; ++it) { + const int doom_result = DoomEntryFromHash(*it, barrier_callback); + DCHECK_EQ(net::ERR_IO_PENDING, doom_result); + index_->Remove(*it); + } + + for (std::vector<uint64>::const_iterator it = mass_doom_entry_hashes->begin(), + end = mass_doom_entry_hashes->end(); + it != end; ++it) { + index_->Remove(*it); + OnDoomStart(*it); + } + + // Taking this pointer here avoids undefined behaviour from calling + // base::Passed before mass_doom_entry_hashes.get(). + std::vector<uint64>* mass_doom_entry_hashes_ptr = + mass_doom_entry_hashes.get(); + PostTaskAndReplyWithResult( + worker_pool_, FROM_HERE, + base::Bind(&SimpleSynchronousEntry::DoomEntrySet, + mass_doom_entry_hashes_ptr, path_), + base::Bind(&SimpleBackendImpl::DoomEntriesComplete, + AsWeakPtr(), base::Passed(&mass_doom_entry_hashes), + barrier_callback)); +} + net::CacheType SimpleBackendImpl::GetCacheType() const { return net::DISK_CACHE; } @@ -278,7 +375,22 @@ int32 SimpleBackendImpl::GetEntryCount() const { int SimpleBackendImpl::OpenEntry(const std::string& key, Entry** entry, const CompletionCallback& callback) { - scoped_refptr<SimpleEntryImpl> simple_entry = CreateOrFindActiveEntry(key); + const uint64 entry_hash = simple_util::GetEntryHashKey(key); + + // TODO(gavinp): Factor out this (not quite completely) repetitive code + // block from OpenEntry/CreateEntry/DoomEntry. + base::hash_map<uint64, std::vector<Closure> >::iterator it = + entries_pending_doom_.find(entry_hash); + if (it != entries_pending_doom_.end()) { + Callback<int(const net::CompletionCallback&)> operation = + base::Bind(&SimpleBackendImpl::OpenEntry, + base::Unretained(this), key, entry); + it->second.push_back(base::Bind(&RunOperationAndCallback, + operation, callback)); + return net::ERR_IO_PENDING; + } + scoped_refptr<SimpleEntryImpl> simple_entry = + CreateOrFindActiveEntry(entry_hash, key); CompletionCallback backend_callback = base::Bind(&SimpleBackendImpl::OnEntryOpenedFromKey, AsWeakPtr(), @@ -292,14 +404,39 @@ int SimpleBackendImpl::OpenEntry(const std::string& key, int SimpleBackendImpl::CreateEntry(const std::string& key, Entry** entry, const CompletionCallback& callback) { - DCHECK(key.size() > 0); - scoped_refptr<SimpleEntryImpl> simple_entry = CreateOrFindActiveEntry(key); + DCHECK_LT(0u, key.size()); + const uint64 entry_hash = simple_util::GetEntryHashKey(key); + + base::hash_map<uint64, std::vector<Closure> >::iterator it = + entries_pending_doom_.find(entry_hash); + if (it != entries_pending_doom_.end()) { + Callback<int(const net::CompletionCallback&)> operation = + base::Bind(&SimpleBackendImpl::CreateEntry, + base::Unretained(this), key, entry); + it->second.push_back(base::Bind(&RunOperationAndCallback, + operation, callback)); + return net::ERR_IO_PENDING; + } + scoped_refptr<SimpleEntryImpl> simple_entry = + CreateOrFindActiveEntry(entry_hash, key); return simple_entry->CreateEntry(entry, callback); } int SimpleBackendImpl::DoomEntry(const std::string& key, const net::CompletionCallback& callback) { - scoped_refptr<SimpleEntryImpl> simple_entry = CreateOrFindActiveEntry(key); + const uint64 entry_hash = simple_util::GetEntryHashKey(key); + + base::hash_map<uint64, std::vector<Closure> >::iterator it = + entries_pending_doom_.find(entry_hash); + if (it != entries_pending_doom_.end()) { + Callback<int(const net::CompletionCallback&)> operation = + base::Bind(&SimpleBackendImpl::DoomEntry, base::Unretained(this), key); + it->second.push_back(base::Bind(&RunOperationAndCallback, + operation, callback)); + return net::ERR_IO_PENDING; + } + scoped_refptr<SimpleEntryImpl> simple_entry = + CreateOrFindActiveEntry(entry_hash, key); return simple_entry->DoomEntry(callback); } @@ -316,28 +453,8 @@ void SimpleBackendImpl::IndexReadyForDoom(Time initial_time, return; } scoped_ptr<std::vector<uint64> > removed_key_hashes( - index_->RemoveEntriesBetween(initial_time, end_time).release()); - - // If any of the entries we are dooming are currently open, we need to remove - // them from |active_entries_|, so that attempts to create new entries will - // succeed and attempts to open them will fail. - for (int i = removed_key_hashes->size() - 1; i >= 0; --i) { - const uint64 entry_hash = (*removed_key_hashes)[i]; - EntryMap::iterator it = active_entries_.find(entry_hash); - if (it == active_entries_.end()) - continue; - SimpleEntryImpl* entry = it->second.get(); - entry->Doom(); - - (*removed_key_hashes)[i] = removed_key_hashes->back(); - removed_key_hashes->resize(removed_key_hashes->size() - 1); - } - - PostTaskAndReplyWithResult( - worker_pool_, FROM_HERE, - base::Bind(&SimpleSynchronousEntry::DoomEntrySet, - base::Passed(&removed_key_hashes), path_), - base::Bind(&CallCompletionCallback, callback)); + index_->GetEntriesBetween(initial_time, end_time).release()); + DoomEntries(removed_key_hashes.get(), callback); } int SimpleBackendImpl::DoomEntriesBetween( @@ -380,7 +497,7 @@ void SimpleBackendImpl::GetStats( } void SimpleBackendImpl::OnExternalCacheHit(const std::string& key) { - index_->UseIfExists(key); + index_->UseIfExists(simple_util::GetEntryHashKey(key)); } void SimpleBackendImpl::InitializeIndex(const CompletionCallback& callback, @@ -421,9 +538,9 @@ SimpleBackendImpl::DiskStatResult SimpleBackendImpl::InitCacheStructureOnDisk( } scoped_refptr<SimpleEntryImpl> SimpleBackendImpl::CreateOrFindActiveEntry( + const uint64 entry_hash, const std::string& key) { - const uint64 entry_hash = simple_util::GetEntryHashKey(key); - + DCHECK_EQ(entry_hash, simple_util::GetEntryHashKey(key)); std::pair<EntryMap::iterator, bool> insert_result = active_entries_.insert(std::make_pair(entry_hash, base::WeakPtr<SimpleEntryImpl>())); @@ -432,7 +549,7 @@ scoped_refptr<SimpleEntryImpl> SimpleBackendImpl::CreateOrFindActiveEntry( DCHECK(!it->second.get()); if (!it->second.get()) { SimpleEntryImpl* entry = new SimpleEntryImpl( - path_, entry_hash, entry_operations_mode_, this, net_log_); + cache_type_, path_, entry_hash, entry_operations_mode_, this, net_log_); entry->SetKey(key); it->second = entry->AsWeakPtr(); } @@ -442,34 +559,73 @@ scoped_refptr<SimpleEntryImpl> SimpleBackendImpl::CreateOrFindActiveEntry( if (key != it->second->key()) { it->second->Doom(); DCHECK_EQ(0U, active_entries_.count(entry_hash)); - return CreateOrFindActiveEntry(key); + return CreateOrFindActiveEntry(entry_hash, key); } return make_scoped_refptr(it->second.get()); } -int SimpleBackendImpl::OpenEntryFromHash(uint64 hash, +int SimpleBackendImpl::OpenEntryFromHash(uint64 entry_hash, Entry** entry, const CompletionCallback& callback) { - EntryMap::iterator has_active = active_entries_.find(hash); - if (has_active != active_entries_.end()) + base::hash_map<uint64, std::vector<Closure> >::iterator it = + entries_pending_doom_.find(entry_hash); + if (it != entries_pending_doom_.end()) { + Callback<int(const net::CompletionCallback&)> operation = + base::Bind(&SimpleBackendImpl::OpenEntryFromHash, + base::Unretained(this), entry_hash, entry); + it->second.push_back(base::Bind(&RunOperationAndCallback, + operation, callback)); + return net::ERR_IO_PENDING; + } + + EntryMap::iterator has_active = active_entries_.find(entry_hash); + if (has_active != active_entries_.end()) { return OpenEntry(has_active->second->key(), entry, callback); + } - scoped_refptr<SimpleEntryImpl> simple_entry = - new SimpleEntryImpl(path_, hash, entry_operations_mode_, this, net_log_); + scoped_refptr<SimpleEntryImpl> simple_entry = new SimpleEntryImpl( + cache_type_, path_, entry_hash, entry_operations_mode_, this, net_log_); CompletionCallback backend_callback = base::Bind(&SimpleBackendImpl::OnEntryOpenedFromHash, - AsWeakPtr(), - hash, entry, simple_entry, callback); + AsWeakPtr(), entry_hash, entry, simple_entry, callback); return simple_entry->OpenEntry(entry, backend_callback); } +int SimpleBackendImpl::DoomEntryFromHash(uint64 entry_hash, + const CompletionCallback& callback) { + Entry** entry = new Entry*(); + scoped_ptr<Entry*> scoped_entry(entry); + + base::hash_map<uint64, std::vector<Closure> >::iterator it = + entries_pending_doom_.find(entry_hash); + if (it != entries_pending_doom_.end()) { + Callback<int(const net::CompletionCallback&)> operation = + base::Bind(&SimpleBackendImpl::DoomEntryFromHash, + base::Unretained(this), entry_hash); + it->second.push_back(base::Bind(&RunOperationAndCallback, + operation, callback)); + return net::ERR_IO_PENDING; + } + + EntryMap::iterator active_it = active_entries_.find(entry_hash); + if (active_it != active_entries_.end()) + return active_it->second->DoomEntry(callback); + + // There's no pending dooms, nor any open entry. We can make a trivial + // call to DoomEntries() to delete this entry. + std::vector<uint64> entry_hash_vector; + entry_hash_vector.push_back(entry_hash); + DoomEntries(&entry_hash_vector, callback); + return net::ERR_IO_PENDING; +} + void SimpleBackendImpl::GetNextEntryInIterator( void** iter, Entry** next_entry, const CompletionCallback& callback, int error_code) { if (error_code != net::OK) { - CallCompletionCallback(callback, error_code); + callback.Run(error_code); return; } if (*iter == NULL) { @@ -494,12 +650,12 @@ void SimpleBackendImpl::GetNextEntryInIterator( if (error_code_open == net::ERR_IO_PENDING) return; if (error_code_open != net::ERR_FAILED) { - CallCompletionCallback(callback, error_code_open); + callback.Run(error_code_open); return; } } } - CallCompletionCallback(callback, net::ERR_FAILED); + callback.Run(net::ERR_FAILED); } void SimpleBackendImpl::OnEntryOpenedFromHash( @@ -509,7 +665,7 @@ void SimpleBackendImpl::OnEntryOpenedFromHash( const CompletionCallback& callback, int error_code) { if (error_code != net::OK) { - CallCompletionCallback(callback, error_code); + callback.Run(error_code); return; } DCHECK(*entry); @@ -522,7 +678,7 @@ void SimpleBackendImpl::OnEntryOpenedFromHash( // There is no active entry corresponding to this hash. The entry created // is put in the map of active entries and returned to the caller. it->second = simple_entry->AsWeakPtr(); - CallCompletionCallback(callback, error_code); + callback.Run(error_code); } else { // The entry was made active with the key while the creation from hash // occurred. The entry created from hash needs to be closed, and the one @@ -550,9 +706,9 @@ void SimpleBackendImpl::OnEntryOpenedFromKey( } else { DCHECK_EQ(simple_entry->entry_hash(), simple_util::GetEntryHashKey(key)); } - UMA_HISTOGRAM_BOOLEAN("SimpleCache.KeyMatchedOnOpen", key_matches); + SIMPLE_CACHE_UMA(BOOLEAN, "KeyMatchedOnOpen", cache_type_, key_matches); } - CallCompletionCallback(callback, final_code); + callback.Run(final_code); } void SimpleBackendImpl::CheckIterationReturnValue( @@ -564,7 +720,23 @@ void SimpleBackendImpl::CheckIterationReturnValue( OpenNextEntry(iter, entry, callback); return; } - CallCompletionCallback(callback, error_code); + callback.Run(error_code); +} + +void SimpleBackendImpl::DoomEntriesComplete( + scoped_ptr<std::vector<uint64> > entry_hashes, + const net::CompletionCallback& callback, + int result) { + std::for_each( + entry_hashes->begin(), entry_hashes->end(), + std::bind1st(std::mem_fun(&SimpleBackendImpl::OnDoomComplete), + this)); + callback.Run(result); +} + +void SimpleBackendImpl::FlushWorkerPoolForTesting() { + if (g_sequenced_worker_pool) + g_sequenced_worker_pool->FlushForTesting(); } } // namespace disk_cache diff --git a/chromium/net/disk_cache/simple/simple_backend_impl.h b/chromium/net/disk_cache/simple/simple_backend_impl.h index 4f01351752e..eb14de8841a 100644 --- a/chromium/net/disk_cache/simple/simple_backend_impl.h +++ b/chromium/net/disk_cache/simple/simple_backend_impl.h @@ -9,6 +9,7 @@ #include <utility> #include <vector> +#include "base/callback_forward.h" #include "base/compiler_specific.h" #include "base/containers/hash_tables.h" #include "base/files/file_path.h" @@ -20,6 +21,7 @@ #include "net/base/cache_type.h" #include "net/disk_cache/disk_cache.h" #include "net/disk_cache/simple/simple_entry_impl.h" +#include "net/disk_cache/simple/simple_index_delegate.h" namespace base { class SingleThreadTaskRunner; @@ -39,15 +41,17 @@ class SimpleEntryImpl; class SimpleIndex; class NET_EXPORT_PRIVATE SimpleBackendImpl : public Backend, + public SimpleIndexDelegate, public base::SupportsWeakPtr<SimpleBackendImpl> { public: SimpleBackendImpl(const base::FilePath& path, int max_bytes, - net::CacheType type, + net::CacheType cache_type, base::SingleThreadTaskRunner* cache_thread, net::NetLog* net_log); virtual ~SimpleBackendImpl(); + net::CacheType cache_type() const { return cache_type_; } SimpleIndex* index() { return index_.get(); } base::TaskRunner* worker_pool() { return worker_pool_.get(); } @@ -64,6 +68,22 @@ class NET_EXPORT_PRIVATE SimpleBackendImpl : public Backend, // operations to construct a new object. void OnDeactivated(const SimpleEntryImpl* entry); + // Flush our SequencedWorkerPool. + static void FlushWorkerPoolForTesting(); + + // The entry for |entry_hash| is being doomed; the backend will not attempt + // run new operations for this |entry_hash| until the Doom is completed. + void OnDoomStart(uint64 entry_hash); + + // The entry for |entry_hash| has been successfully doomed, we can now allow + // operations on this entry, and we can run any operations enqueued while the + // doom completed. + void OnDoomComplete(uint64 entry_hash); + + // SimpleIndexDelegate: + virtual void DoomEntries(std::vector<uint64>* entry_hashes, + const CompletionCallback& callback) OVERRIDE; + // Backend: virtual net::CacheType GetCacheType() const OVERRIDE; virtual int32 GetEntryCount() const OVERRIDE; @@ -118,16 +138,22 @@ class NET_EXPORT_PRIVATE SimpleBackendImpl : public Backend, // Searches |active_entries_| for the entry corresponding to |key|. If found, // returns the found entry. Otherwise, creates a new entry and returns that. scoped_refptr<SimpleEntryImpl> CreateOrFindActiveEntry( + uint64 entry_hash, const std::string& key); // Given a hash, will try to open the corresponding Entry. If we have an Entry // corresponding to |hash| in the map of active entries, opens it. Otherwise, // a new empty Entry will be created, opened and filled with information from // the disk. - int OpenEntryFromHash(uint64 hash, + int OpenEntryFromHash(uint64 entry_hash, Entry** entry, const CompletionCallback& callback); + // Doom the entry corresponding to |entry_hash|, if it's active or currently + // pending doom. This function does not block if there is an active entry, + // which is very important to prevent races in DoomEntries() above. + int DoomEntryFromHash(uint64 entry_hash, const CompletionCallback & callback); + // Called when the index is initilized to find the next entry in the iterator // |iter|. If there are no more hashes in the iterator list, net::ERR_FAILED // is returned. Otherwise, calls OpenEntryFromHash. @@ -162,7 +188,14 @@ class NET_EXPORT_PRIVATE SimpleBackendImpl : public Backend, const CompletionCallback& callback, int error_code); + // A callback thunk used by DoomEntries to clear the |entries_pending_doom_| + // after a mass doom. + void DoomEntriesComplete(scoped_ptr<std::vector<uint64> > entry_hashes, + const CompletionCallback& callback, + int result); + const base::FilePath path_; + const net::CacheType cache_type_; scoped_ptr<SimpleIndex> index_; const scoped_refptr<base::SingleThreadTaskRunner> cache_thread_; scoped_refptr<base::TaskRunner> worker_pool_; @@ -170,10 +203,14 @@ class NET_EXPORT_PRIVATE SimpleBackendImpl : public Backend, int orig_max_size_; const SimpleEntryImpl::OperationsMode entry_operations_mode_; - // TODO(gavinp): Store the entry_hash in SimpleEntryImpl, and index this map - // by hash. This will save memory, and make IndexReadyForDoom easier. EntryMap active_entries_; + // The set of all entries which are currently being doomed. To avoid races, + // these entries cannot have Doom/Create/Open operations run until the doom + // is complete. The base::Closure map target is used to store deferred + // operations to be run at the completion of the Doom. + base::hash_map<uint64, std::vector<base::Closure> > entries_pending_doom_; + net::NetLog* const net_log_; }; diff --git a/chromium/net/disk_cache/simple/simple_backend_version.h b/chromium/net/disk_cache/simple/simple_backend_version.h new file mode 100644 index 00000000000..fe350a2e886 --- /dev/null +++ b/chromium/net/disk_cache/simple/simple_backend_version.h @@ -0,0 +1,27 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_DISK_CACHE_SIMPLE_SIMPLE_BACKEND_VERSION_H_ +#define NET_DISK_CACHE_SIMPLE_SIMPLE_BACKEND_VERSION_H_ + +namespace disk_cache { + +// Short rules helping to think about data upgrades within Simple Cache: +// * ALL changes of on-disk data format, backward-compatible or not, +// forward-compatible or not, require updating the |kSimpleVersion|. +// * All cache Upgrades are performed on backend start, must be finished +// before the new backend starts processing any incoming operations. +// * If the Upgrade is not implemented for transition from +// |kSimpleVersion - 1| then the whole cache directory will be cleared. +// * Dropping cache data on disk or some of its parts can be a valid way to +// Upgrade. +const uint32 kSimpleVersion = 6; + +// The version of the entry file(s) as written to disk. Must be updated iff the +// entry format changes with the overall backend version update. +const uint32 kSimpleEntryVersionOnDisk = 5; + +} // namespace disk_cache + +#endif // NET_DISK_CACHE_SIMPLE_SIMPLE_BACKEND_VERSION_H_ diff --git a/chromium/net/disk_cache/simple/simple_entry_format.h b/chromium/net/disk_cache/simple/simple_entry_format.h index d06ab1139c5..8224b858dcc 100644 --- a/chromium/net/disk_cache/simple/simple_entry_format.h +++ b/chromium/net/disk_cache/simple/simple_entry_format.h @@ -19,17 +19,21 @@ namespace disk_cache { const uint64 kSimpleInitialMagicNumber = GG_UINT64_C(0xfcfb6d1ba7725c30); const uint64 kSimpleFinalMagicNumber = GG_UINT64_C(0xf4fa6f45970d41d8); -// A file in the Simple cache consists of a SimpleFileHeader followed -// by data. +// A file containing stream 0 and stream 1 in the Simple cache consists of: +// - a SimpleFileHeader. +// - the key. +// - the data from stream 1. +// - a SimpleFileEOF record for stream 1. +// - the data from stream 0. +// - a SimpleFileEOF record for stream 0. -// A file in the Simple cache consists of: +// A file containing stream 2 in the Simple cache consists of: // - a SimpleFileHeader. // - the key. // - the data. // - at the end, a SimpleFileEOF record. -const uint32 kSimpleVersion = 4; - -static const int kSimpleEntryFileCount = 3; +static const int kSimpleEntryFileCount = 2; +static const int kSimpleEntryStreamCount = 3; struct NET_EXPORT_PRIVATE SimpleFileHeader { SimpleFileHeader(); @@ -40,7 +44,7 @@ struct NET_EXPORT_PRIVATE SimpleFileHeader { uint32 key_hash; }; -struct SimpleFileEOF { +struct NET_EXPORT_PRIVATE SimpleFileEOF { enum Flags { FLAG_HAS_CRC32 = (1U << 0), }; @@ -50,6 +54,8 @@ struct SimpleFileEOF { uint64 final_magic_number; uint32 flags; uint32 data_crc32; + // |stream_size| is only used in the EOF record for stream 0. + uint32 stream_size; }; } // namespace disk_cache diff --git a/chromium/net/disk_cache/simple/simple_entry_format_history.h b/chromium/net/disk_cache/simple/simple_entry_format_history.h new file mode 100644 index 00000000000..f7b818a59e9 --- /dev/null +++ b/chromium/net/disk_cache/simple/simple_entry_format_history.h @@ -0,0 +1,62 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_DISK_CACHE_SIMPLE_SIMPLE_ENTRY_FORMAT_HISTORY_H_ +#define NET_DISK_CACHE_SIMPLE_SIMPLE_ENTRY_FORMAT_HISTORY_H_ + +#include "base/basictypes.h" +#include "base/port.h" +#include "net/base/net_export.h" + +namespace disk_cache { + +namespace simplecache_v5 { + +const uint64 kSimpleInitialMagicNumber = GG_UINT64_C(0xfcfb6d1ba7725c30); +const uint64 kSimpleFinalMagicNumber = GG_UINT64_C(0xf4fa6f45970d41d8); + +// A file containing stream 0 and stream 1 in the Simple cache consists of: +// - a SimpleFileHeader. +// - the key. +// - the data from stream 1. +// - a SimpleFileEOF record for stream 1. +// - the data from stream 0. +// - a SimpleFileEOF record for stream 0. + +// A file containing stream 2 in the Simple cache consists of: +// - a SimpleFileHeader. +// - the key. +// - the data. +// - at the end, a SimpleFileEOF record. +static const int kSimpleEntryFileCount = 2; +static const int kSimpleEntryStreamCount = 3; + +struct NET_EXPORT_PRIVATE SimpleFileHeader { + SimpleFileHeader(); + + uint64 initial_magic_number; + uint32 version; + uint32 key_length; + uint32 key_hash; +}; + +struct NET_EXPORT_PRIVATE SimpleFileEOF { + enum Flags { + FLAG_HAS_CRC32 = (1U << 0), + }; + + SimpleFileEOF(); + + uint64 final_magic_number; + uint32 flags; + uint32 data_crc32; + // |stream_size| is only used in the EOF record for stream 0. + uint32 stream_size; +}; + +} // namespace simplecache_v5 + +} // namespace disk_cache + +#endif // NET_DISK_CACHE_SIMPLE_SIMPLE_ENTRY_FORMAT_HISTORY_H_ diff --git a/chromium/net/disk_cache/simple/simple_entry_impl.cc b/chromium/net/disk_cache/simple/simple_entry_impl.cc index cbee6048371..3d2bc22cc36 100644 --- a/chromium/net/disk_cache/simple/simple_entry_impl.cc +++ b/chromium/net/disk_cache/simple/simple_entry_impl.cc @@ -14,19 +14,21 @@ #include "base/location.h" #include "base/logging.h" #include "base/message_loop/message_loop_proxy.h" -#include "base/metrics/histogram.h" #include "base/task_runner.h" +#include "base/task_runner_util.h" #include "base/time/time.h" #include "net/base/io_buffer.h" #include "net/base/net_errors.h" #include "net/disk_cache/net_log_parameters.h" #include "net/disk_cache/simple/simple_backend_impl.h" +#include "net/disk_cache/simple/simple_histogram_macros.h" #include "net/disk_cache/simple/simple_index.h" #include "net/disk_cache/simple/simple_net_log_parameters.h" #include "net/disk_cache/simple/simple_synchronous_entry.h" #include "net/disk_cache/simple/simple_util.h" #include "third_party/zlib/zlib.h" +namespace disk_cache { namespace { // Used in histograms, please only add entries at the end. @@ -48,7 +50,8 @@ enum WriteResult { WRITE_RESULT_OVER_MAX_SIZE = 2, WRITE_RESULT_BAD_STATE = 3, WRITE_RESULT_SYNC_WRITE_FAILURE = 4, - WRITE_RESULT_MAX = 5, + WRITE_RESULT_FAST_EMPTY_RETURN = 5, + WRITE_RESULT_MAX = 6, }; // Used in histograms, please only add entries at the end. @@ -61,21 +64,23 @@ enum HeaderSizeChange { HEADER_SIZE_CHANGE_MAX }; -void RecordReadResult(ReadResult result) { - UMA_HISTOGRAM_ENUMERATION("SimpleCache.ReadResult", result, READ_RESULT_MAX); -}; +void RecordReadResult(net::CacheType cache_type, ReadResult result) { + SIMPLE_CACHE_UMA(ENUMERATION, + "ReadResult", cache_type, result, READ_RESULT_MAX); +} -void RecordWriteResult(WriteResult result) { - UMA_HISTOGRAM_ENUMERATION("SimpleCache.WriteResult", - result, WRITE_RESULT_MAX); -}; +void RecordWriteResult(net::CacheType cache_type, WriteResult result) { + SIMPLE_CACHE_UMA(ENUMERATION, + "WriteResult2", cache_type, result, WRITE_RESULT_MAX); +} // TODO(ttuttle): Consider removing this once we have a good handle on header // size changes. -void RecordHeaderSizeChange(int old_size, int new_size) { +void RecordHeaderSizeChange(net::CacheType cache_type, + int old_size, int new_size) { HeaderSizeChange size_change; - UMA_HISTOGRAM_COUNTS_10000("SimpleCache.HeaderSize", new_size); + SIMPLE_CACHE_UMA(COUNTS_10000, "HeaderSize", cache_type, new_size); if (old_size == 0) { size_change = HEADER_SIZE_CHANGE_INITIAL; @@ -83,51 +88,52 @@ void RecordHeaderSizeChange(int old_size, int new_size) { size_change = HEADER_SIZE_CHANGE_SAME; } else if (new_size > old_size) { int delta = new_size - old_size; - UMA_HISTOGRAM_COUNTS_10000("SimpleCache.HeaderSizeIncreaseAbsolute", - delta); - UMA_HISTOGRAM_PERCENTAGE("SimpleCache.HeaderSizeIncreasePercentage", - delta * 100 / old_size); + SIMPLE_CACHE_UMA(COUNTS_10000, + "HeaderSizeIncreaseAbsolute", cache_type, delta); + SIMPLE_CACHE_UMA(PERCENTAGE, + "HeaderSizeIncreasePercentage", cache_type, + delta * 100 / old_size); size_change = HEADER_SIZE_CHANGE_INCREASE; } else { // new_size < old_size int delta = old_size - new_size; - UMA_HISTOGRAM_COUNTS_10000("SimpleCache.HeaderSizeDecreaseAbsolute", - delta); - UMA_HISTOGRAM_PERCENTAGE("SimpleCache.HeaderSizeDecreasePercentage", - delta * 100 / old_size); + SIMPLE_CACHE_UMA(COUNTS_10000, + "HeaderSizeDecreaseAbsolute", cache_type, delta); + SIMPLE_CACHE_UMA(PERCENTAGE, + "HeaderSizeDecreasePercentage", cache_type, + delta * 100 / old_size); size_change = HEADER_SIZE_CHANGE_DECREASE; } - UMA_HISTOGRAM_ENUMERATION("SimpleCache.HeaderSizeChange", - size_change, - HEADER_SIZE_CHANGE_MAX); -} - -void RecordUnexpectedStream0Write() { - UMA_HISTOGRAM_ENUMERATION("SimpleCache.HeaderSizeChange", - HEADER_SIZE_CHANGE_UNEXPECTED_WRITE, - HEADER_SIZE_CHANGE_MAX); + SIMPLE_CACHE_UMA(ENUMERATION, + "HeaderSizeChange", cache_type, + size_change, HEADER_SIZE_CHANGE_MAX); } -// Short trampoline to take an owned input parameter and call a net completion -// callback with its value. -void CallCompletionCallback(const net::CompletionCallback& callback, - scoped_ptr<int> result) { - DCHECK(result); - if (!callback.is_null()) - callback.Run(*result); +void RecordUnexpectedStream0Write(net::CacheType cache_type) { + SIMPLE_CACHE_UMA(ENUMERATION, + "HeaderSizeChange", cache_type, + HEADER_SIZE_CHANGE_UNEXPECTED_WRITE, HEADER_SIZE_CHANGE_MAX); } int g_open_entry_count = 0; -void AdjustOpenEntryCountBy(int offset) { +void AdjustOpenEntryCountBy(net::CacheType cache_type, int offset) { g_open_entry_count += offset; - UMA_HISTOGRAM_COUNTS_10000("SimpleCache.GlobalOpenEntryCount", - g_open_entry_count); + SIMPLE_CACHE_UMA(COUNTS_10000, + "GlobalOpenEntryCount", cache_type, g_open_entry_count); } -} // namespace +void InvokeCallbackIfBackendIsAlive( + const base::WeakPtr<SimpleBackendImpl>& backend, + const net::CompletionCallback& completion_callback, + int result) { + DCHECK(!completion_callback.is_null()); + if (!backend.get()) + return; + completion_callback.Run(result); +} -namespace disk_cache { +} // namespace using base::Closure; using base::FilePath; @@ -150,12 +156,14 @@ class SimpleEntryImpl::ScopedOperationRunner { SimpleEntryImpl* const entry_; }; -SimpleEntryImpl::SimpleEntryImpl(const FilePath& path, +SimpleEntryImpl::SimpleEntryImpl(net::CacheType cache_type, + const FilePath& path, const uint64 entry_hash, OperationsMode operations_mode, SimpleBackendImpl* backend, net::NetLog* net_log) : backend_(backend->AsWeakPtr()), + cache_type_(cache_type), worker_pool_(backend->worker_pool()), path_(path), entry_hash_(entry_hash), @@ -163,10 +171,12 @@ SimpleEntryImpl::SimpleEntryImpl(const FilePath& path, last_used_(Time::Now()), last_modified_(last_used_), open_count_(0), + doomed_(false), state_(STATE_UNINITIALIZED), synchronous_entry_(NULL), net_log_(net::BoundNetLog::Make( - net_log, net::NetLog::SOURCE_DISK_CACHE_ENTRY)) { + net_log, net::NetLog::SOURCE_DISK_CACHE_ENTRY)), + stream_0_data_(new net::GrowableIOBuffer()) { COMPILE_ASSERT(arraysize(data_size_) == arraysize(crc32s_end_offset_), arrays_should_be_same_size); COMPILE_ASSERT(arraysize(data_size_) == arraysize(crc32s_), @@ -201,8 +211,9 @@ int SimpleEntryImpl::OpenEntry(Entry** out_entry, else open_entry_index_enum = INDEX_MISS; } - UMA_HISTOGRAM_ENUMERATION("SimpleCache.OpenEntryIndexState", - open_entry_index_enum, INDEX_MAX); + SIMPLE_CACHE_UMA(ENUMERATION, + "OpenEntryIndexState", cache_type_, + open_entry_index_enum, INDEX_MAX); // If entry is not known to the index, initiate fast failover to the network. if (open_entry_index_enum == INDEX_MISS) { @@ -246,23 +257,23 @@ int SimpleEntryImpl::CreateEntry(Entry** out_entry, // have the entry in the index but we don't have the created files yet, this // way we never leak files. CreationOperationComplete will remove the entry // from the index if the creation fails. - backend_->index()->Insert(key_); + backend_->index()->Insert(entry_hash_); RunNextOperationIfNeeded(); return ret_value; } int SimpleEntryImpl::DoomEntry(const CompletionCallback& callback) { + if (doomed_) + return net::OK; net_log_.AddEvent(net::NetLog::TYPE_SIMPLE_CACHE_ENTRY_DOOM_CALL); net_log_.AddEvent(net::NetLog::TYPE_SIMPLE_CACHE_ENTRY_DOOM_BEGIN); MarkAsDoomed(); - scoped_ptr<int> result(new int()); - Closure task = base::Bind(&SimpleSynchronousEntry::DoomEntry, path_, key_, - entry_hash_, result.get()); - Closure reply = base::Bind(&CallCompletionCallback, - callback, base::Passed(&result)); - worker_pool_->PostTaskAndReply(FROM_HERE, task, reply); + if (backend_.get()) + backend_->OnDoomStart(entry_hash_); + pending_operations_.push(SimpleEntryOperation::DoomOperation(this, callback)); + RunNextOperationIfNeeded(); return net::ERR_IO_PENDING; } @@ -328,14 +339,14 @@ int SimpleEntryImpl::ReadData(int stream_index, false)); } - if (stream_index < 0 || stream_index >= kSimpleEntryFileCount || + if (stream_index < 0 || stream_index >= kSimpleEntryStreamCount || buf_len < 0) { if (net_log_.IsLoggingAllEvents()) { net_log_.AddEvent(net::NetLog::TYPE_SIMPLE_CACHE_ENTRY_READ_END, CreateNetLogReadWriteCompleteCallback(net::ERR_INVALID_ARGUMENT)); } - RecordReadResult(READ_RESULT_INVALID_ARGUMENT); + RecordReadResult(cache_type_, READ_RESULT_INVALID_ARGUMENT); return net::ERR_INVALID_ARGUMENT; } if (pending_operations_.empty() && (offset >= GetDataSize(stream_index) || @@ -345,10 +356,12 @@ int SimpleEntryImpl::ReadData(int stream_index, CreateNetLogReadWriteCompleteCallback(0)); } - RecordReadResult(READ_RESULT_NONBLOCK_EMPTY_RETURN); + RecordReadResult(cache_type_, READ_RESULT_NONBLOCK_EMPTY_RETURN); return 0; } + // TODO(clamy): return immediatly when reading from stream 0. + // TODO(felipeg): Optimization: Add support for truly parallel read // operations. bool alone_in_queue = @@ -374,14 +387,14 @@ int SimpleEntryImpl::WriteData(int stream_index, truncate)); } - if (stream_index < 0 || stream_index >= kSimpleEntryFileCount || offset < 0 || - buf_len < 0) { + if (stream_index < 0 || stream_index >= kSimpleEntryStreamCount || + offset < 0 || buf_len < 0) { if (net_log_.IsLoggingAllEvents()) { net_log_.AddEvent( net::NetLog::TYPE_SIMPLE_CACHE_ENTRY_WRITE_END, CreateNetLogReadWriteCompleteCallback(net::ERR_INVALID_ARGUMENT)); } - RecordWriteResult(WRITE_RESULT_INVALID_ARGUMENT); + RecordWriteResult(cache_type_, WRITE_RESULT_INVALID_ARGUMENT); return net::ERR_INVALID_ARGUMENT; } if (backend_.get() && offset + buf_len > backend_->GetMaxFileSize()) { @@ -390,21 +403,16 @@ int SimpleEntryImpl::WriteData(int stream_index, net::NetLog::TYPE_SIMPLE_CACHE_ENTRY_WRITE_END, CreateNetLogReadWriteCompleteCallback(net::ERR_FAILED)); } - RecordWriteResult(WRITE_RESULT_OVER_MAX_SIZE); + RecordWriteResult(cache_type_, WRITE_RESULT_OVER_MAX_SIZE); return net::ERR_FAILED; } ScopedOperationRunner operation_runner(this); - // Currently, Simple Cache is only used for HTTP, which stores the headers in - // stream 0 and always writes them with a single, truncating write. Detect - // these writes and record the size and size changes of the headers. Also, - // note writes to stream 0 that violate those assumptions. - if (stream_index == 0) { - if (offset == 0 && truncate) - RecordHeaderSizeChange(data_size_[0], buf_len); - else - RecordUnexpectedStream0Write(); - } + // Stream 0 data is kept in memory, so can be written immediatly if there are + // no IO operations pending. + if (stream_index == 0 && state_ == STATE_READY && + pending_operations_.size() == 0) + return SetStream0Data(buf, offset, buf_len, truncate); // We can only do optimistic Write if there is no pending operations, so // that we are sure that the next call to RunNextOperationIfNeeded will @@ -508,6 +516,17 @@ SimpleEntryImpl::~SimpleEntryImpl() { net_log_.EndEvent(net::NetLog::TYPE_SIMPLE_CACHE_ENTRY); } +void SimpleEntryImpl::PostClientCallback(const CompletionCallback& callback, + int result) { + if (callback.is_null()) + return; + // Note that the callback is posted rather than directly invoked to avoid + // reentrancy issues. + MessageLoopProxy::current()->PostTask( + FROM_HERE, + base::Bind(&InvokeCallbackIfBackendIsAlive, backend_, callback, result)); +} + void SimpleEntryImpl::MakeUninitialized() { state_ = STATE_UNINITIALIZED; std::memset(crc32s_end_offset_, 0, sizeof(crc32s_end_offset_)); @@ -523,6 +542,15 @@ void SimpleEntryImpl::ReturnEntryToCaller(Entry** out_entry) { DCHECK(out_entry); ++open_count_; AddRef(); // Balanced in Close() + if (!backend_.get()) { + // This method can be called when an asynchronous operation completed. + // If the backend no longer exists, the callback won't be invoked, and so we + // must close ourselves to avoid leaking. As well, there's no guarantee the + // client-provided pointer (|out_entry|) hasn't been freed, and no point + // dereferencing it, either. + Close(); + return; + } *out_entry = this; } @@ -530,20 +558,21 @@ void SimpleEntryImpl::RemoveSelfFromBackend() { if (!backend_.get()) return; backend_->OnDeactivated(this); - backend_.reset(); } void SimpleEntryImpl::MarkAsDoomed() { + doomed_ = true; if (!backend_.get()) return; - backend_->index()->Remove(key_); + backend_->index()->Remove(entry_hash_); RemoveSelfFromBackend(); } void SimpleEntryImpl::RunNextOperationIfNeeded() { DCHECK(io_thread_checker_.CalledOnValidThread()); - UMA_HISTOGRAM_CUSTOM_COUNTS("SimpleCache.EntryOperationsPending", - pending_operations_.size(), 0, 100, 20); + SIMPLE_CACHE_UMA(CUSTOM_COUNTS, + "EntryOperationsPending", cache_type_, + pending_operations_.size(), 0, 100, 20); if (!pending_operations_.empty() && state_ != STATE_IO_PENDING) { scoped_ptr<SimpleEntryOperation> operation( new SimpleEntryOperation(pending_operations_.front())); @@ -579,6 +608,9 @@ void SimpleEntryImpl::RunNextOperationIfNeeded() { operation->callback(), operation->truncate()); break; + case SimpleEntryOperation::TYPE_DOOM: + DoomEntryInternal(operation->callback()); + break; default: NOTREACHED(); } @@ -599,17 +631,14 @@ void SimpleEntryImpl::OpenEntryInternal(bool have_index, if (state_ == STATE_READY) { ReturnEntryToCaller(out_entry); - MessageLoopProxy::current()->PostTask(FROM_HERE, base::Bind(callback, - net::OK)); + PostClientCallback(callback, net::OK); net_log_.AddEvent( net::NetLog::TYPE_SIMPLE_CACHE_ENTRY_OPEN_END, CreateNetLogSimpleEntryCreationCallback(this, net::OK)); return; - } else if (state_ == STATE_FAILURE) { - if (!callback.is_null()) { - MessageLoopProxy::current()->PostTask(FROM_HERE, base::Bind( - callback, net::ERR_FAILED)); - } + } + if (state_ == STATE_FAILURE) { + PostClientCallback(callback, net::ERR_FAILED); net_log_.AddEvent( net::NetLog::TYPE_SIMPLE_CACHE_ENTRY_OPEN_END, CreateNetLogSimpleEntryCreationCallback(this, net::ERR_FAILED)); @@ -624,6 +653,7 @@ void SimpleEntryImpl::OpenEntryInternal(bool have_index, new SimpleEntryCreationResults( SimpleEntryStat(last_used_, last_modified_, data_size_))); Closure task = base::Bind(&SimpleSynchronousEntry::OpenEntry, + cache_type_, path_, entry_hash_, have_index, @@ -650,11 +680,7 @@ void SimpleEntryImpl::CreateEntryInternal(bool have_index, net_log_.AddEvent( net::NetLog::TYPE_SIMPLE_CACHE_ENTRY_CREATE_END, CreateNetLogSimpleEntryCreationCallback(this, net::ERR_FAILED)); - - if (!callback.is_null()) { - MessageLoopProxy::current()->PostTask(FROM_HERE, base::Bind( - callback, net::ERR_FAILED)); - } + PostClientCallback(callback, net::ERR_FAILED); return; } DCHECK_EQ(STATE_UNINITIALIZED, state_); @@ -667,7 +693,7 @@ void SimpleEntryImpl::CreateEntryInternal(bool have_index, last_used_ = last_modified_ = base::Time::Now(); // If creation succeeds, we should mark all streams to be saved on close. - for (int i = 0; i < kSimpleEntryFileCount; ++i) + for (int i = 0; i < kSimpleEntryStreamCount; ++i) have_written_[i] = true; const base::TimeTicks start_time = base::TimeTicks::Now(); @@ -675,6 +701,7 @@ void SimpleEntryImpl::CreateEntryInternal(bool have_index, new SimpleEntryCreationResults( SimpleEntryStat(last_used_, last_modified_, data_size_))); Closure task = base::Bind(&SimpleSynchronousEntry::CreateEntry, + cache_type_, path_, key_, entry_hash_, @@ -701,7 +728,7 @@ void SimpleEntryImpl::CloseInternal() { if (state_ == STATE_READY) { DCHECK(synchronous_entry_); state_ = STATE_IO_PENDING; - for (int i = 0; i < kSimpleEntryFileCount; ++i) { + for (int i = 0; i < kSimpleEntryStreamCount; ++i) { if (have_written_[i]) { if (GetDataSize(i) == crc32s_end_offset_[i]) { int32 crc = GetDataSize(i) == 0 ? crc32(0, Z_NULL, 0) : crc32s_[i]; @@ -720,19 +747,20 @@ void SimpleEntryImpl::CloseInternal() { base::Bind(&SimpleSynchronousEntry::Close, base::Unretained(synchronous_entry_), SimpleEntryStat(last_used_, last_modified_, data_size_), - base::Passed(&crc32s_to_write)); + base::Passed(&crc32s_to_write), + stream_0_data_); Closure reply = base::Bind(&SimpleEntryImpl::CloseOperationComplete, this); synchronous_entry_ = NULL; worker_pool_->PostTaskAndReply(FROM_HERE, task, reply); - for (int i = 0; i < kSimpleEntryFileCount; ++i) { + for (int i = 0; i < kSimpleEntryStreamCount; ++i) { if (!have_written_[i]) { - UMA_HISTOGRAM_ENUMERATION("SimpleCache.CheckCRCResult", - crc_check_state_[i], CRC_CHECK_MAX); + SIMPLE_CACHE_UMA(ENUMERATION, + "CheckCRCResult", cache_type_, + crc_check_state_[i], CRC_CHECK_MAX); } } } else { - synchronous_entry_ = NULL; CloseOperationComplete(); } } @@ -754,9 +782,12 @@ void SimpleEntryImpl::ReadDataInternal(int stream_index, if (state_ == STATE_FAILURE || state_ == STATE_UNINITIALIZED) { if (!callback.is_null()) { - RecordReadResult(READ_RESULT_BAD_STATE); - MessageLoopProxy::current()->PostTask(FROM_HERE, base::Bind( - callback, net::ERR_FAILED)); + RecordReadResult(cache_type_, READ_RESULT_BAD_STATE); + // Note that the API states that client-provided callbacks for entry-level + // (i.e. non-backend) operations (e.g. read, write) are invoked even if + // the backend was already destroyed. + MessageLoopProxy::current()->PostTask( + FROM_HERE, base::Bind(callback, net::ERR_FAILED)); } if (net_log_.IsLoggingAllEvents()) { net_log_.AddEvent( @@ -767,31 +798,41 @@ void SimpleEntryImpl::ReadDataInternal(int stream_index, } DCHECK_EQ(STATE_READY, state_); if (offset >= GetDataSize(stream_index) || offset < 0 || !buf_len) { - RecordReadResult(READ_RESULT_FAST_EMPTY_RETURN); + RecordReadResult(cache_type_, READ_RESULT_FAST_EMPTY_RETURN); // If there is nothing to read, we bail out before setting state_ to // STATE_IO_PENDING. if (!callback.is_null()) - MessageLoopProxy::current()->PostTask(FROM_HERE, base::Bind( - callback, 0)); + MessageLoopProxy::current()->PostTask(FROM_HERE, base::Bind(callback, 0)); return; } buf_len = std::min(buf_len, GetDataSize(stream_index) - offset); + // Since stream 0 data is kept in memory, it is read immediately. + if (stream_index == 0) { + int ret_value = ReadStream0Data(buf, offset, buf_len); + if (!callback.is_null()) { + MessageLoopProxy::current()->PostTask(FROM_HERE, + base::Bind(callback, ret_value)); + } + return; + } + state_ = STATE_IO_PENDING; - if (backend_.get()) - backend_->index()->UseIfExists(key_); + if (!doomed_ && backend_.get()) + backend_->index()->UseIfExists(entry_hash_); scoped_ptr<uint32> read_crc32(new uint32()); scoped_ptr<int> result(new int()); - scoped_ptr<base::Time> last_used(new base::Time()); + scoped_ptr<SimpleEntryStat> entry_stat( + new SimpleEntryStat(last_used_, last_modified_, data_size_)); Closure task = base::Bind( &SimpleSynchronousEntry::ReadData, base::Unretained(synchronous_entry_), SimpleSynchronousEntry::EntryOperationData(stream_index, offset, buf_len), make_scoped_refptr(buf), read_crc32.get(), - last_used.get(), + entry_stat.get(), result.get()); Closure reply = base::Bind(&SimpleEntryImpl::ReadOperationComplete, this, @@ -799,7 +840,7 @@ void SimpleEntryImpl::ReadDataInternal(int stream_index, offset, callback, base::Passed(&read_crc32), - base::Passed(&last_used), + base::Passed(&entry_stat), base::Passed(&result)); worker_pool_->PostTaskAndReply(FROM_HERE, task, reply); } @@ -821,41 +862,49 @@ void SimpleEntryImpl::WriteDataInternal(int stream_index, } if (state_ == STATE_FAILURE || state_ == STATE_UNINITIALIZED) { - RecordWriteResult(WRITE_RESULT_BAD_STATE); + RecordWriteResult(cache_type_, WRITE_RESULT_BAD_STATE); if (net_log_.IsLoggingAllEvents()) { net_log_.AddEvent( net::NetLog::TYPE_SIMPLE_CACHE_ENTRY_WRITE_END, CreateNetLogReadWriteCompleteCallback(net::ERR_FAILED)); } if (!callback.is_null()) { - // We need to posttask so that we don't go in a loop when we call the - // callback directly. - MessageLoopProxy::current()->PostTask(FROM_HERE, base::Bind( - callback, net::ERR_FAILED)); + MessageLoopProxy::current()->PostTask( + FROM_HERE, base::Bind(callback, net::ERR_FAILED)); } // |this| may be destroyed after return here. return; } DCHECK_EQ(STATE_READY, state_); - state_ = STATE_IO_PENDING; - if (backend_.get()) - backend_->index()->UseIfExists(key_); - // It is easy to incrementally compute the CRC from [0 .. |offset + buf_len|) - // if |offset == 0| or we have already computed the CRC for [0 .. offset). - // We rely on most write operations being sequential, start to end to compute - // the crc of the data. When we write to an entry and close without having - // done a sequential write, we don't check the CRC on read. - if (offset == 0 || crc32s_end_offset_[stream_index] == offset) { - uint32 initial_crc = (offset != 0) ? crc32s_[stream_index] - : crc32(0, Z_NULL, 0); - if (buf_len > 0) { - crc32s_[stream_index] = crc32(initial_crc, - reinterpret_cast<const Bytef*>(buf->data()), - buf_len); + + // Since stream 0 data is kept in memory, it will be written immediatly. + if (stream_index == 0) { + int ret_value = SetStream0Data(buf, offset, buf_len, truncate); + if (!callback.is_null()) { + MessageLoopProxy::current()->PostTask(FROM_HERE, + base::Bind(callback, ret_value)); + } + return; + } + + // Ignore zero-length writes that do not change the file size. + if (buf_len == 0) { + int32 data_size = data_size_[stream_index]; + if (truncate ? (offset == data_size) : (offset <= data_size)) { + RecordWriteResult(cache_type_, WRITE_RESULT_FAST_EMPTY_RETURN); + if (!callback.is_null()) { + MessageLoopProxy::current()->PostTask(FROM_HERE, base::Bind( + callback, 0)); + } + return; } - crc32s_end_offset_[stream_index] = offset + buf_len; } + state_ = STATE_IO_PENDING; + if (!doomed_ && backend_.get()) + backend_->index()->UseIfExists(entry_hash_); + + AdvanceCrc(buf, offset, buf_len, stream_index); // |entry_stat| needs to be initialized before modifying |data_size_|. scoped_ptr<SimpleEntryStat> entry_stat( @@ -872,12 +921,17 @@ void SimpleEntryImpl::WriteDataInternal(int stream_index, last_used_ = last_modified_ = base::Time::Now(); have_written_[stream_index] = true; + // Writing on stream 1 affects the placement of stream 0 in the file, the EOF + // record will have to be rewritten. + if (stream_index == 1) + have_written_[0] = true; scoped_ptr<int> result(new int()); Closure task = base::Bind(&SimpleSynchronousEntry::WriteData, base::Unretained(synchronous_entry_), SimpleSynchronousEntry::EntryOperationData( - stream_index, offset, buf_len, truncate), + stream_index, offset, buf_len, truncate, + doomed_), make_scoped_refptr(buf), entry_stat.get(), result.get()); @@ -890,6 +944,15 @@ void SimpleEntryImpl::WriteDataInternal(int stream_index, worker_pool_->PostTaskAndReply(FROM_HERE, task, reply); } +void SimpleEntryImpl::DoomEntryInternal(const CompletionCallback& callback) { + PostTaskAndReplyWithResult( + worker_pool_, FROM_HERE, + base::Bind(&SimpleSynchronousEntry::DoomEntry, path_, entry_hash_), + base::Bind(&SimpleEntryImpl::DoomOperationComplete, this, callback, + state_)); + state_ = STATE_IO_PENDING; +} + void SimpleEntryImpl::CreationOperationComplete( const CompletionCallback& completion_callback, const base::TimeTicks& start_time, @@ -900,18 +963,15 @@ void SimpleEntryImpl::CreationOperationComplete( DCHECK_EQ(state_, STATE_IO_PENDING); DCHECK(in_results); ScopedOperationRunner operation_runner(this); - UMA_HISTOGRAM_BOOLEAN( - "SimpleCache.EntryCreationResult", in_results->result == net::OK); + SIMPLE_CACHE_UMA(BOOLEAN, + "EntryCreationResult", cache_type_, + in_results->result == net::OK); if (in_results->result != net::OK) { if (in_results->result != net::ERR_FILE_EXISTS) MarkAsDoomed(); net_log_.AddEventWithNetErrorCode(end_event_type, net::ERR_FAILED); - - if (!completion_callback.is_null()) { - MessageLoopProxy::current()->PostTask(FROM_HERE, base::Bind( - completion_callback, net::ERR_FAILED)); - } + PostClientCallback(completion_callback, net::ERR_FAILED); MakeUninitialized(); return; } @@ -922,6 +982,13 @@ void SimpleEntryImpl::CreationOperationComplete( state_ = STATE_READY; synchronous_entry_ = in_results->sync_entry; + if (in_results->stream_0_data) { + stream_0_data_ = in_results->stream_0_data; + // The crc was read in SimpleSynchronousEntry. + crc_check_state_[0] = CRC_CHECK_DONE; + crc32s_[0] = in_results->stream_0_crc32; + crc32s_end_offset_[0] = in_results->entry_stat.data_size(0); + } if (key_.empty()) { SetKey(synchronous_entry_->key()); } else { @@ -930,15 +997,13 @@ void SimpleEntryImpl::CreationOperationComplete( DCHECK_EQ(key_, synchronous_entry_->key()); } UpdateDataFromEntryStat(in_results->entry_stat); - UMA_HISTOGRAM_TIMES("SimpleCache.EntryCreationTime", - (base::TimeTicks::Now() - start_time)); - AdjustOpenEntryCountBy(1); + SIMPLE_CACHE_UMA(TIMES, + "EntryCreationTime", cache_type_, + (base::TimeTicks::Now() - start_time)); + AdjustOpenEntryCountBy(cache_type_, 1); net_log_.AddEvent(end_event_type); - if (!completion_callback.is_null()) { - MessageLoopProxy::current()->PostTask(FROM_HERE, base::Bind( - completion_callback, net::OK)); - } + PostClientCallback(completion_callback, net::OK); } void SimpleEntryImpl::EntryOperationComplete( @@ -971,7 +1036,7 @@ void SimpleEntryImpl::ReadOperationComplete( int offset, const CompletionCallback& completion_callback, scoped_ptr<uint32> read_crc32, - scoped_ptr<base::Time> last_used, + scoped_ptr<SimpleEntryStat> entry_stat, scoped_ptr<int> result) { DCHECK(io_thread_checker_.CalledOnValidThread()); DCHECK(synchronous_entry_); @@ -1007,7 +1072,7 @@ void SimpleEntryImpl::ReadOperationComplete( Closure task = base::Bind(&SimpleSynchronousEntry::CheckEOFRecord, base::Unretained(synchronous_entry_), stream_index, - data_size_[stream_index], + *entry_stat, crc32s_[stream_index], new_result.get()); Closure reply = base::Bind(&SimpleEntryImpl::ChecksumOperationComplete, @@ -1021,9 +1086,9 @@ void SimpleEntryImpl::ReadOperationComplete( } if (*result < 0) { - RecordReadResult(READ_RESULT_SYNC_READ_FAILURE); + RecordReadResult(cache_type_, READ_RESULT_SYNC_READ_FAILURE); } else { - RecordReadResult(READ_RESULT_SUCCESS); + RecordReadResult(cache_type_, READ_RESULT_SUCCESS); if (crc_check_state_[stream_index] == CRC_CHECK_NEVER_READ_TO_END && offset + *result == GetDataSize(stream_index)) { crc_check_state_[stream_index] = CRC_CHECK_NOT_DONE; @@ -1036,10 +1101,7 @@ void SimpleEntryImpl::ReadOperationComplete( } EntryOperationComplete( - stream_index, - completion_callback, - SimpleEntryStat(*last_used, last_modified_, data_size_), - result.Pass()); + stream_index, completion_callback, *entry_stat, result.Pass()); } void SimpleEntryImpl::WriteOperationComplete( @@ -1048,9 +1110,9 @@ void SimpleEntryImpl::WriteOperationComplete( scoped_ptr<SimpleEntryStat> entry_stat, scoped_ptr<int> result) { if (*result >= 0) - RecordWriteResult(WRITE_RESULT_SUCCESS); + RecordWriteResult(cache_type_, WRITE_RESULT_SUCCESS); else - RecordWriteResult(WRITE_RESULT_SYNC_WRITE_FAILURE); + RecordWriteResult(cache_type_, WRITE_RESULT_SYNC_WRITE_FAILURE); if (net_log_.IsLoggingAllEvents()) { net_log_.AddEvent(net::NetLog::TYPE_SIMPLE_CACHE_ENTRY_WRITE_END, CreateNetLogReadWriteCompleteCallback(*result)); @@ -1060,6 +1122,18 @@ void SimpleEntryImpl::WriteOperationComplete( stream_index, completion_callback, *entry_stat, result.Pass()); } +void SimpleEntryImpl::DoomOperationComplete( + const CompletionCallback& callback, + State state_to_restore, + int result) { + state_ = state_to_restore; + if (!callback.is_null()) + callback.Run(result); + RunNextOperationIfNeeded(); + if (backend_) + backend_->OnDoomComplete(entry_hash_); +} + void SimpleEntryImpl::ChecksumOperationComplete( int orig_result, int stream_index, @@ -1079,11 +1153,11 @@ void SimpleEntryImpl::ChecksumOperationComplete( if (*result == net::OK) { *result = orig_result; if (orig_result >= 0) - RecordReadResult(READ_RESULT_SUCCESS); + RecordReadResult(cache_type_, READ_RESULT_SUCCESS); else - RecordReadResult(READ_RESULT_SYNC_READ_FAILURE); + RecordReadResult(cache_type_, READ_RESULT_SYNC_READ_FAILURE); } else { - RecordReadResult(READ_RESULT_SYNC_CHECKSUM_FAILURE); + RecordReadResult(cache_type_, READ_RESULT_SYNC_CHECKSUM_FAILURE); } if (net_log_.IsLoggingAllEvents()) { net_log_.AddEvent(net::NetLog::TYPE_SIMPLE_CACHE_ENTRY_READ_END, @@ -1103,7 +1177,7 @@ void SimpleEntryImpl::CloseOperationComplete() { DCHECK(STATE_IO_PENDING == state_ || STATE_FAILURE == state_ || STATE_UNINITIALIZED == state_); net_log_.AddEvent(net::NetLog::TYPE_SIMPLE_CACHE_ENTRY_CLOSE_END); - AdjustOpenEntryCountBy(-1); + AdjustOpenEntryCountBy(cache_type_, -1); MakeUninitialized(); RunNextOperationIfNeeded(); } @@ -1114,18 +1188,18 @@ void SimpleEntryImpl::UpdateDataFromEntryStat( DCHECK(synchronous_entry_); DCHECK_EQ(STATE_READY, state_); - last_used_ = entry_stat.last_used; - last_modified_ = entry_stat.last_modified; - for (int i = 0; i < kSimpleEntryFileCount; ++i) { - data_size_[i] = entry_stat.data_size[i]; + last_used_ = entry_stat.last_used(); + last_modified_ = entry_stat.last_modified(); + for (int i = 0; i < kSimpleEntryStreamCount; ++i) { + data_size_[i] = entry_stat.data_size(i); } - if (backend_.get()) - backend_->index()->UpdateEntrySize(key_, GetDiskUsage()); + if (!doomed_ && backend_.get()) + backend_->index()->UpdateEntrySize(entry_hash_, GetDiskUsage()); } int64 SimpleEntryImpl::GetDiskUsage() const { int64 file_size = 0; - for (int i = 0; i < kSimpleEntryFileCount; ++i) { + for (int i = 0; i < kSimpleEntryStreamCount; ++i) { file_size += simple_util::GetFileSizeFromKeyAndDataSize(key_, data_size_[i]); } @@ -1136,13 +1210,31 @@ void SimpleEntryImpl::RecordReadIsParallelizable( const SimpleEntryOperation& operation) const { if (!executing_operation_) return; - // TODO(clamy): The values of this histogram should be changed to something - // more useful. - bool parallelizable_read = - !operation.alone_in_queue() && - executing_operation_->type() == SimpleEntryOperation::TYPE_READ; - UMA_HISTOGRAM_BOOLEAN("SimpleCache.ReadIsParallelizable", - parallelizable_read); + // Used in histograms, please only add entries at the end. + enum ReadDependencyType { + // READ_STANDALONE = 0, Deprecated. + READ_FOLLOWS_READ = 1, + READ_FOLLOWS_CONFLICTING_WRITE = 2, + READ_FOLLOWS_NON_CONFLICTING_WRITE = 3, + READ_FOLLOWS_OTHER = 4, + READ_ALONE_IN_QUEUE = 5, + READ_DEPENDENCY_TYPE_MAX = 6, + }; + + ReadDependencyType type = READ_FOLLOWS_OTHER; + if (operation.alone_in_queue()) { + type = READ_ALONE_IN_QUEUE; + } else if (executing_operation_->type() == SimpleEntryOperation::TYPE_READ) { + type = READ_FOLLOWS_READ; + } else if (executing_operation_->type() == SimpleEntryOperation::TYPE_WRITE) { + if (executing_operation_->ConflictsWith(operation)) + type = READ_FOLLOWS_CONFLICTING_WRITE; + else + type = READ_FOLLOWS_NON_CONFLICTING_WRITE; + } + SIMPLE_CACHE_UMA(ENUMERATION, + "ReadIsParallelizable", cache_type_, + type, READ_DEPENDENCY_TYPE_MAX); } void SimpleEntryImpl::RecordWriteDependencyType( @@ -1180,8 +1272,85 @@ void SimpleEntryImpl::RecordWriteDependencyType( : WRITE_FOLLOWS_NON_CONFLICTING_WRITE; } } - UMA_HISTOGRAM_ENUMERATION( - "SimpleCache.WriteDependencyType", type, WRITE_DEPENDENCY_TYPE_MAX); + SIMPLE_CACHE_UMA(ENUMERATION, + "WriteDependencyType", cache_type_, + type, WRITE_DEPENDENCY_TYPE_MAX); +} + +int SimpleEntryImpl::ReadStream0Data(net::IOBuffer* buf, + int offset, + int buf_len) { + if (buf_len < 0) { + RecordReadResult(cache_type_, READ_RESULT_SYNC_READ_FAILURE); + return 0; + } + memcpy(buf->data(), stream_0_data_->data() + offset, buf_len); + UpdateDataFromEntryStat( + SimpleEntryStat(base::Time::Now(), last_modified_, data_size_)); + RecordReadResult(cache_type_, READ_RESULT_SUCCESS); + return buf_len; +} + +int SimpleEntryImpl::SetStream0Data(net::IOBuffer* buf, + int offset, + int buf_len, + bool truncate) { + // Currently, stream 0 is only used for HTTP headers, and always writes them + // with a single, truncating write. Detect these writes and record the size + // changes of the headers. Also, support writes to stream 0 that have + // different access patterns, as required by the API contract. + // All other clients of the Simple Cache are encouraged to use stream 1. + have_written_[0] = true; + int data_size = GetDataSize(0); + if (offset == 0 && truncate) { + RecordHeaderSizeChange(cache_type_, data_size, buf_len); + stream_0_data_->SetCapacity(buf_len); + memcpy(stream_0_data_->data(), buf->data(), buf_len); + data_size_[0] = buf_len; + } else { + RecordUnexpectedStream0Write(cache_type_); + const int buffer_size = + truncate ? offset + buf_len : std::max(offset + buf_len, data_size); + stream_0_data_->SetCapacity(buffer_size); + // If |stream_0_data_| was extended, the extension until offset needs to be + // zero-filled. + const int fill_size = offset <= data_size ? 0 : offset - data_size; + if (fill_size > 0) + memset(stream_0_data_->data() + data_size, 0, fill_size); + if (buf) + memcpy(stream_0_data_->data() + offset, buf->data(), buf_len); + data_size_[0] = buffer_size; + } + base::Time modification_time = base::Time::Now(); + AdvanceCrc(buf, offset, buf_len, 0); + UpdateDataFromEntryStat( + SimpleEntryStat(modification_time, modification_time, data_size_)); + RecordWriteResult(cache_type_, WRITE_RESULT_SUCCESS); + return buf_len; +} + +void SimpleEntryImpl::AdvanceCrc(net::IOBuffer* buffer, + int offset, + int length, + int stream_index) { + // It is easy to incrementally compute the CRC from [0 .. |offset + buf_len|) + // if |offset == 0| or we have already computed the CRC for [0 .. offset). + // We rely on most write operations being sequential, start to end to compute + // the crc of the data. When we write to an entry and close without having + // done a sequential write, we don't check the CRC on read. + if (offset == 0 || crc32s_end_offset_[stream_index] == offset) { + uint32 initial_crc = + (offset != 0) ? crc32s_[stream_index] : crc32(0, Z_NULL, 0); + if (length > 0) { + crc32s_[stream_index] = crc32( + initial_crc, reinterpret_cast<const Bytef*>(buffer->data()), length); + } + crc32s_end_offset_[stream_index] = offset + length; + } else if (offset < crc32s_end_offset_[stream_index]) { + // If a range for which the crc32 was already computed is rewritten, the + // computation of the crc32 need to start from 0 again. + crc32s_end_offset_[stream_index] = 0; + } } } // namespace disk_cache diff --git a/chromium/net/disk_cache/simple/simple_entry_impl.h b/chromium/net/disk_cache/simple/simple_entry_impl.h index 7eb8914e873..e2f0c63b39e 100644 --- a/chromium/net/disk_cache/simple/simple_entry_impl.h +++ b/chromium/net/disk_cache/simple/simple_entry_impl.h @@ -13,6 +13,8 @@ #include "base/memory/scoped_ptr.h" #include "base/memory/weak_ptr.h" #include "base/threading/thread_checker.h" +#include "net/base/cache_type.h" +#include "net/base/net_export.h" #include "net/base/net_log.h" #include "net/disk_cache/disk_cache.h" #include "net/disk_cache/simple/simple_entry_format.h" @@ -23,6 +25,7 @@ class TaskRunner; } namespace net { +class GrowableIOBuffer; class IOBuffer; } @@ -30,13 +33,14 @@ namespace disk_cache { class SimpleBackendImpl; class SimpleSynchronousEntry; -struct SimpleEntryStat; +class SimpleEntryStat; struct SimpleEntryCreationResults; // SimpleEntryImpl is the IO thread interface to an entry in the very simple // disk cache. It proxies for the SimpleSynchronousEntry, which performs IO // on the worker thread. -class SimpleEntryImpl : public Entry, public base::RefCounted<SimpleEntryImpl>, +class NET_EXPORT_PRIVATE SimpleEntryImpl : public Entry, + public base::RefCounted<SimpleEntryImpl>, public base::SupportsWeakPtr<SimpleEntryImpl> { friend class base::RefCounted<SimpleEntryImpl>; public: @@ -45,7 +49,8 @@ class SimpleEntryImpl : public Entry, public base::RefCounted<SimpleEntryImpl>, OPTIMISTIC_OPERATIONS, }; - SimpleEntryImpl(const base::FilePath& path, + SimpleEntryImpl(net::CacheType cache_type, + const base::FilePath& path, uint64 entry_hash, OperationsMode operations_mode, SimpleBackendImpl* backend, @@ -132,6 +137,12 @@ class SimpleEntryImpl : public Entry, public base::RefCounted<SimpleEntryImpl>, virtual ~SimpleEntryImpl(); + // Must be used to invoke a client-provided completion callback for an + // operation initiated through the backend (e.g. create, open) so that clients + // don't get notified after they deleted the backend (which they would not + // expect). + void PostClientCallback(const CompletionCallback& callback, int result); + // Sets entry to STATE_UNINITIALIZED. void MakeUninitialized(); @@ -177,6 +188,8 @@ class SimpleEntryImpl : public Entry, public base::RefCounted<SimpleEntryImpl>, const CompletionCallback& callback, bool truncate); + void DoomEntryInternal(const CompletionCallback& callback); + // Called after a SimpleSynchronousEntry has completed CreateEntry() or // OpenEntry(). If |in_sync_entry| is non-NULL, creation is successful and we // can return |this| SimpleEntryImpl to |*out_entry|. Runs @@ -205,7 +218,7 @@ class SimpleEntryImpl : public Entry, public base::RefCounted<SimpleEntryImpl>, int offset, const CompletionCallback& completion_callback, scoped_ptr<uint32> read_crc32, - scoped_ptr<base::Time> last_used, + scoped_ptr<SimpleEntryStat> entry_stat, scoped_ptr<int> result); // Called after an asynchronous write completes. @@ -214,6 +227,11 @@ class SimpleEntryImpl : public Entry, public base::RefCounted<SimpleEntryImpl>, scoped_ptr<SimpleEntryStat> entry_stat, scoped_ptr<int> result); + // Called after an asynchronous doom completes. + void DoomOperationComplete(const CompletionCallback& callback, + State state_to_restore, + int result); + // Called after validating the checksums on an entry. Passes through the // original result if successful, propogates the error if the checksum does // not validate. @@ -234,11 +252,29 @@ class SimpleEntryImpl : public Entry, public base::RefCounted<SimpleEntryImpl>, void RecordReadIsParallelizable(const SimpleEntryOperation& operation) const; void RecordWriteDependencyType(const SimpleEntryOperation& operation) const; + // Reads from the stream 0 data kept in memory. + int ReadStream0Data(net::IOBuffer* buf, int offset, int buf_len); + + // Copies data from |buf| to the internal in-memory buffer for stream 0. If + // |truncate| is set to true, the target buffer will be truncated at |offset| + // + |buf_len| before being written. + int SetStream0Data(net::IOBuffer* buf, + int offset, int buf_len, + bool truncate); + + // Updates |crc32s_| and |crc32s_end_offset_| for a write of the data in + // |buffer| on |stream_index|, starting at |offset| and of length |length|. + void AdvanceCrc(net::IOBuffer* buffer, + int offset, + int length, + int stream_index); + // All nonstatic SimpleEntryImpl methods should always be called on the IO // thread, in all cases. |io_thread_checker_| documents and enforces this. base::ThreadChecker io_thread_checker_; - base::WeakPtr<SimpleBackendImpl> backend_; + const base::WeakPtr<SimpleBackendImpl> backend_; + const net::CacheType cache_type_; const scoped_refptr<base::TaskRunner> worker_pool_; const base::FilePath path_; const uint64 entry_hash_; @@ -250,32 +286,38 @@ class SimpleEntryImpl : public Entry, public base::RefCounted<SimpleEntryImpl>, // TODO(clamy): Unify last_used_ with data in the index. base::Time last_used_; base::Time last_modified_; - int32 data_size_[kSimpleEntryFileCount]; + int32 data_size_[kSimpleEntryStreamCount]; // Number of times this object has been returned from Backend::OpenEntry() and // Backend::CreateEntry() without subsequent Entry::Close() calls. Used to // notify the backend when this entry not used by any callers. int open_count_; + bool doomed_; + State state_; // When possible, we compute a crc32, for the data in each entry as we read or // write. For each stream, |crc32s_[index]| is the crc32 of that stream from // [0 .. |crc32s_end_offset_|). If |crc32s_end_offset_[index] == 0| then the // value of |crc32s_[index]| is undefined. - int32 crc32s_end_offset_[kSimpleEntryFileCount]; - uint32 crc32s_[kSimpleEntryFileCount]; + int32 crc32s_end_offset_[kSimpleEntryStreamCount]; + uint32 crc32s_[kSimpleEntryStreamCount]; - // If |have_written_[index]| is true, we have written to the stream |index|. - bool have_written_[kSimpleEntryFileCount]; + // If |have_written_[index]| is true, we have written to the file that + // contains stream |index|. + bool have_written_[kSimpleEntryStreamCount]; // Reflects how much CRC checking has been done with the entry. This state is // reported on closing each entry stream. - CheckCrcResult crc_check_state_[kSimpleEntryFileCount]; + CheckCrcResult crc_check_state_[kSimpleEntryStreamCount]; // The |synchronous_entry_| is the worker thread object that performs IO on - // entries. It's owned by this SimpleEntryImpl whenever |operation_running_| - // is false (i.e. when an operation is not pending on the worker pool). + // entries. It's owned by this SimpleEntryImpl whenever |executing_operation_| + // is false (i.e. when an operation is not pending on the worker pool). When + // an operation is being executed no one owns the synchronous entry. Therefore + // SimpleEntryImpl should not be deleted while an operation is running as that + // would leak the SimpleSynchronousEntry. SimpleSynchronousEntry* synchronous_entry_; std::queue<SimpleEntryOperation> pending_operations_; @@ -283,6 +325,17 @@ class SimpleEntryImpl : public Entry, public base::RefCounted<SimpleEntryImpl>, net::BoundNetLog net_log_; scoped_ptr<SimpleEntryOperation> executing_operation_; + + // Unlike other streams, stream 0 data is read from the disk when the entry is + // opened, and then kept in memory. All read/write operations on stream 0 + // affect the |stream_0_data_| buffer. When the entry is closed, + // |stream_0_data_| is written to the disk. + // Stream 0 is kept in memory because it is stored in the same file as stream + // 1 on disk, to reduce the number of file descriptors and save disk space. + // This strategy allows stream 1 to change size easily. Since stream 0 is only + // used to write HTTP headers, the memory consumption of keeping it in memory + // is acceptable. + scoped_refptr<net::GrowableIOBuffer> stream_0_data_; }; } // namespace disk_cache diff --git a/chromium/net/disk_cache/simple/simple_entry_operation.cc b/chromium/net/disk_cache/simple/simple_entry_operation.cc index 81d5f7c888b..d4e76082084 100644 --- a/chromium/net/disk_cache/simple/simple_entry_operation.cc +++ b/chromium/net/disk_cache/simple/simple_entry_operation.cc @@ -28,7 +28,7 @@ SimpleEntryOperation::SimpleEntryOperation(const SimpleEntryOperation& other) SimpleEntryOperation::~SimpleEntryOperation() {} -// Static. +// static SimpleEntryOperation SimpleEntryOperation::OpenOperation( SimpleEntryImpl* entry, bool have_index, @@ -48,7 +48,7 @@ SimpleEntryOperation SimpleEntryOperation::OpenOperation( false); } -// Static. +// static SimpleEntryOperation SimpleEntryOperation::CreateOperation( SimpleEntryImpl* entry, bool have_index, @@ -68,7 +68,7 @@ SimpleEntryOperation SimpleEntryOperation::CreateOperation( false); } -// Static. +// static SimpleEntryOperation SimpleEntryOperation::CloseOperation( SimpleEntryImpl* entry) { return SimpleEntryOperation(entry, @@ -85,7 +85,7 @@ SimpleEntryOperation SimpleEntryOperation::CloseOperation( false); } -// Static. +// static SimpleEntryOperation SimpleEntryOperation::ReadOperation( SimpleEntryImpl* entry, int index, @@ -108,7 +108,7 @@ SimpleEntryOperation SimpleEntryOperation::ReadOperation( alone_in_queue); } -// Static. +// static SimpleEntryOperation SimpleEntryOperation::WriteOperation( SimpleEntryImpl* entry, int index, @@ -132,6 +132,33 @@ SimpleEntryOperation SimpleEntryOperation::WriteOperation( false); } +// static +SimpleEntryOperation SimpleEntryOperation::DoomOperation( + SimpleEntryImpl* entry, + const CompletionCallback& callback) { + net::IOBuffer* const buf = NULL; + Entry** const out_entry = NULL; + const int offset = 0; + const int length = 0; + const bool have_index = false; + const int index = 0; + const bool truncate = false; + const bool optimistic = false; + const bool alone_in_queue = false; + return SimpleEntryOperation(entry, + buf, + callback, + out_entry, + offset, + length, + TYPE_DOOM, + have_index, + index, + truncate, + optimistic, + alone_in_queue); +} + bool SimpleEntryOperation::ConflictsWith( const SimpleEntryOperation& other_op) const { if (type_ != TYPE_READ && type_ != TYPE_WRITE) diff --git a/chromium/net/disk_cache/simple/simple_entry_operation.h b/chromium/net/disk_cache/simple/simple_entry_operation.h index acdd60a3207..1c787017229 100644 --- a/chromium/net/disk_cache/simple/simple_entry_operation.h +++ b/chromium/net/disk_cache/simple/simple_entry_operation.h @@ -31,6 +31,7 @@ class SimpleEntryOperation { TYPE_CLOSE = 2, TYPE_READ = 3, TYPE_WRITE = 4, + TYPE_DOOM = 5, }; SimpleEntryOperation(const SimpleEntryOperation& other); @@ -63,6 +64,10 @@ class SimpleEntryOperation { bool optimistic, const CompletionCallback& callback); + static SimpleEntryOperation DoomOperation( + SimpleEntryImpl* entry, + const CompletionCallback& callback); + bool ConflictsWith(const SimpleEntryOperation& other_op) const; // Releases all references. After calling this operation, SimpleEntryOperation // will only hold POD members. diff --git a/chromium/net/disk_cache/simple/simple_histogram_macros.h b/chromium/net/disk_cache/simple/simple_histogram_macros.h new file mode 100644 index 00000000000..2107ad466a7 --- /dev/null +++ b/chromium/net/disk_cache/simple/simple_histogram_macros.h @@ -0,0 +1,35 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_DISK_CACHE_SIMPLE_SIMPLE_HISTOGRAM_MACROS_H_ +#define NET_DISK_CACHE_SIMPLE_SIMPLE_HISTOGRAM_MACROS_H_ + +#include "base/metrics/histogram.h" +#include "base/metrics/sparse_histogram.h" +#include "net/base/cache_type.h" + +// This file contains macros used to report histograms. The main issue is that +// we want to have separate histograms for each type of cache (http vs app), +// while making it easy to report histograms and have all names precomputed. + +#define SIMPLE_CACHE_THUNK(uma_type, args) UMA_HISTOGRAM_##uma_type args + +#define SIMPLE_CACHE_UMA(uma_type, uma_name, cache_type, ...) \ + do { \ + switch (cache_type) { \ + case net::DISK_CACHE: \ + SIMPLE_CACHE_THUNK( \ + uma_type, ("SimpleCache.Http." uma_name, ##__VA_ARGS__)); \ + break; \ + case net::APP_CACHE: \ + SIMPLE_CACHE_THUNK( \ + uma_type, ("SimpleCache.App." uma_name, ##__VA_ARGS__)); \ + break; \ + default: \ + NOTREACHED(); \ + break; \ + } \ + } while (0) + +#endif // NET_DISK_CACHE_SIMPLE_SIMPLE_HISTOGRAM_MACROS_H_ diff --git a/chromium/net/disk_cache/simple/simple_index.cc b/chromium/net/disk_cache/simple/simple_index.cc index 78ce87ed71d..dafc0e72177 100644 --- a/chromium/net/disk_cache/simple/simple_index.cc +++ b/chromium/net/disk_cache/simple/simple_index.cc @@ -4,6 +4,9 @@ #include "net/disk_cache/simple/simple_index.h" +#include <algorithm> +#include <limits> +#include <string> #include <utility> #include "base/bind.h" @@ -13,7 +16,6 @@ #include "base/logging.h" #include "base/message_loop/message_loop.h" #include "base/metrics/field_trial.h" -#include "base/metrics/histogram.h" #include "base/pickle.h" #include "base/strings/string_number_conversions.h" #include "base/strings/string_tokenizer.h" @@ -22,6 +24,8 @@ #include "base/time/time.h" #include "net/base/net_errors.h" #include "net/disk_cache/simple/simple_entry_format.h" +#include "net/disk_cache/simple/simple_histogram_macros.h" +#include "net/disk_cache/simple/simple_index_delegate.h" #include "net/disk_cache/simple/simple_index_file.h" #include "net/disk_cache/simple/simple_synchronous_entry.h" #include "net/disk_cache/simple/simple_util.h" @@ -72,43 +76,77 @@ bool CompareHashesForTimestamp::operator()(uint64 hash1, uint64 hash2) { namespace disk_cache { -EntryMetadata::EntryMetadata() : last_used_time_(0), entry_size_(0) {} +EntryMetadata::EntryMetadata() + : last_used_time_seconds_since_epoch_(0), + entry_size_(0) { +} -EntryMetadata::EntryMetadata(base::Time last_used_time, uint64 entry_size) - : last_used_time_(last_used_time.ToInternalValue()), - entry_size_(entry_size) {} +EntryMetadata::EntryMetadata(base::Time last_used_time, int entry_size) + : last_used_time_seconds_since_epoch_(0), + entry_size_(entry_size) { + SetLastUsedTime(last_used_time); +} base::Time EntryMetadata::GetLastUsedTime() const { - return base::Time::FromInternalValue(last_used_time_); + // Preserve nullity. + if (last_used_time_seconds_since_epoch_ == 0) + return base::Time(); + + return base::Time::UnixEpoch() + + base::TimeDelta::FromSeconds(last_used_time_seconds_since_epoch_); } void EntryMetadata::SetLastUsedTime(const base::Time& last_used_time) { - last_used_time_ = last_used_time.ToInternalValue(); + // Preserve nullity. + if (last_used_time.is_null()) { + last_used_time_seconds_since_epoch_ = 0; + return; + } + + const base::TimeDelta since_unix_epoch = + last_used_time - base::Time::UnixEpoch(); + const int64 seconds_since_unix_epoch = since_unix_epoch.InSeconds(); + DCHECK_LE(implicit_cast<int64>(std::numeric_limits<uint32>::min()), + seconds_since_unix_epoch); + DCHECK_GE(implicit_cast<int64>(std::numeric_limits<uint32>::max()), + seconds_since_unix_epoch); + + last_used_time_seconds_since_epoch_ = seconds_since_unix_epoch; + // Avoid accidental nullity. + if (last_used_time_seconds_since_epoch_ == 0) + last_used_time_seconds_since_epoch_ = 1; } void EntryMetadata::Serialize(Pickle* pickle) const { DCHECK(pickle); - COMPILE_ASSERT(sizeof(EntryMetadata) == (sizeof(int64) + sizeof(uint64)), - EntryMetadata_has_two_member_variables); - pickle->WriteInt64(last_used_time_); + int64 internal_last_used_time = GetLastUsedTime().ToInternalValue(); + pickle->WriteInt64(internal_last_used_time); pickle->WriteUInt64(entry_size_); } bool EntryMetadata::Deserialize(PickleIterator* it) { DCHECK(it); - return it->ReadInt64(&last_used_time_) && it->ReadUInt64(&entry_size_); + int64 tmp_last_used_time; + uint64 tmp_entry_size; + if (!it->ReadInt64(&tmp_last_used_time) || !it->ReadUInt64(&tmp_entry_size)) + return false; + SetLastUsedTime(base::Time::FromInternalValue(tmp_last_used_time)); + entry_size_ = tmp_entry_size; + return true; } SimpleIndex::SimpleIndex(base::SingleThreadTaskRunner* io_thread, - const base::FilePath& cache_directory, + SimpleIndexDelegate* delegate, + net::CacheType cache_type, scoped_ptr<SimpleIndexFile> index_file) - : cache_size_(0), + : delegate_(delegate), + cache_type_(cache_type), + cache_size_(0), max_size_(0), high_watermark_(0), low_watermark_(0), eviction_in_progress_(false), initialized_(false), - cache_directory_(cache_directory), index_file_(index_file.Pass()), io_thread_(io_thread), // Creating the callback once so it is reused every time @@ -148,8 +186,10 @@ void SimpleIndex::Initialize(base::Time cache_mtime) { } #if defined(OS_ANDROID) - activity_status_listener_.reset(new base::android::ActivityStatus::Listener( - base::Bind(&SimpleIndex::OnActivityStateChange, AsWeakPtr()))); + if (base::android::IsVMInitialized()) { + activity_status_listener_.reset(new base::android::ActivityStatus::Listener( + base::Bind(&SimpleIndex::OnActivityStateChange, AsWeakPtr()))); + } #endif SimpleIndexLoadResult* load_result = new SimpleIndexLoadResult(); @@ -184,14 +224,32 @@ int SimpleIndex::ExecuteWhenReady(const net::CompletionCallback& task) { return net::ERR_IO_PENDING; } -scoped_ptr<SimpleIndex::HashList> SimpleIndex::RemoveEntriesBetween( - const base::Time initial_time, const base::Time end_time) { - return ExtractEntriesBetween(initial_time, end_time, true); +scoped_ptr<SimpleIndex::HashList> SimpleIndex::GetEntriesBetween( + base::Time initial_time, base::Time end_time) { + DCHECK_EQ(true, initialized_); + + if (!initial_time.is_null()) + initial_time -= EntryMetadata::GetLowerEpsilonForTimeComparisons(); + if (end_time.is_null()) + end_time = base::Time::Max(); + else + end_time += EntryMetadata::GetUpperEpsilonForTimeComparisons(); + const base::Time extended_end_time = + end_time.is_null() ? base::Time::Max() : end_time; + DCHECK(extended_end_time >= initial_time); + scoped_ptr<HashList> ret_hashes(new HashList()); + for (EntrySet::iterator it = entries_set_.begin(), end = entries_set_.end(); + it != end; ++it) { + EntryMetadata& metadata = it->second; + base::Time entry_time = metadata.GetLastUsedTime(); + if (initial_time <= entry_time && entry_time < extended_end_time) + ret_hashes->push_back(it->first); + } + return ret_hashes.Pass(); } scoped_ptr<SimpleIndex::HashList> SimpleIndex::GetAllHashes() { - const base::Time null_time = base::Time(); - return ExtractEntriesBetween(null_time, null_time, false); + return GetEntriesBetween(base::Time(), base::Time()); } int32 SimpleIndex::GetEntryCount() const { @@ -199,30 +257,28 @@ int32 SimpleIndex::GetEntryCount() const { return entries_set_.size(); } -void SimpleIndex::Insert(const std::string& key) { +void SimpleIndex::Insert(uint64 entry_hash) { DCHECK(io_thread_checker_.CalledOnValidThread()); // Upon insert we don't know yet the size of the entry. // It will be updated later when the SimpleEntryImpl finishes opening or // creating the new entry, and then UpdateEntrySize will be called. - const uint64 hash_key = simple_util::GetEntryHashKey(key); InsertInEntrySet( - hash_key, EntryMetadata(base::Time::Now(), 0), &entries_set_); + entry_hash, EntryMetadata(base::Time::Now(), 0), &entries_set_); if (!initialized_) - removed_entries_.erase(hash_key); + removed_entries_.erase(entry_hash); PostponeWritingToDisk(); } -void SimpleIndex::Remove(const std::string& key) { +void SimpleIndex::Remove(uint64 entry_hash) { DCHECK(io_thread_checker_.CalledOnValidThread()); - const uint64 hash_key = simple_util::GetEntryHashKey(key); - EntrySet::iterator it = entries_set_.find(hash_key); + EntrySet::iterator it = entries_set_.find(entry_hash); if (it != entries_set_.end()) { UpdateEntryIteratorSize(&it, 0); entries_set_.erase(it); } if (!initialized_) - removed_entries_.insert(hash_key); + removed_entries_.insert(entry_hash); PostponeWritingToDisk(); } @@ -232,11 +288,11 @@ bool SimpleIndex::Has(uint64 hash) const { return !initialized_ || entries_set_.count(hash) > 0; } -bool SimpleIndex::UseIfExists(const std::string& key) { +bool SimpleIndex::UseIfExists(uint64 entry_hash) { DCHECK(io_thread_checker_.CalledOnValidThread()); // Always update the last used time, even if it is during initialization. // It will be merged later. - EntrySet::iterator it = entries_set_.find(simple_util::GetEntryHashKey(key)); + EntrySet::iterator it = entries_set_.find(entry_hash); if (it == entries_set_.end()) // If not initialized, always return true, forcing it to go to the disk. return !initialized_; @@ -249,52 +305,54 @@ void SimpleIndex::StartEvictionIfNeeded() { DCHECK(io_thread_checker_.CalledOnValidThread()); if (eviction_in_progress_ || cache_size_ <= high_watermark_) return; - // Take all live key hashes from the index and sort them by time. eviction_in_progress_ = true; eviction_start_time_ = base::TimeTicks::Now(); - UMA_HISTOGRAM_MEMORY_KB("SimpleCache.Eviction.CacheSizeOnStart2", - cache_size_ / kBytesInKb); - UMA_HISTOGRAM_MEMORY_KB("SimpleCache.Eviction.MaxCacheSizeOnStart2", - max_size_ / kBytesInKb); - scoped_ptr<std::vector<uint64> > entry_hashes(new std::vector<uint64>()); + SIMPLE_CACHE_UMA(MEMORY_KB, + "Eviction.CacheSizeOnStart2", cache_type_, + cache_size_ / kBytesInKb); + SIMPLE_CACHE_UMA(MEMORY_KB, + "Eviction.MaxCacheSizeOnStart2", cache_type_, + max_size_ / kBytesInKb); + std::vector<uint64> entry_hashes; + entry_hashes.reserve(entries_set_.size()); for (EntrySet::const_iterator it = entries_set_.begin(), end = entries_set_.end(); it != end; ++it) { - entry_hashes->push_back(it->first); + entry_hashes.push_back(it->first); } - std::sort(entry_hashes->begin(), entry_hashes->end(), + std::sort(entry_hashes.begin(), entry_hashes.end(), CompareHashesForTimestamp(entries_set_)); // Remove as many entries from the index to get below |low_watermark_|. - std::vector<uint64>::iterator it = entry_hashes->begin(); + std::vector<uint64>::iterator it = entry_hashes.begin(); uint64 evicted_so_far_size = 0; while (evicted_so_far_size < cache_size_ - low_watermark_) { - DCHECK(it != entry_hashes->end()); + DCHECK(it != entry_hashes.end()); EntrySet::iterator found_meta = entries_set_.find(*it); DCHECK(found_meta != entries_set_.end()); uint64 to_evict_size = found_meta->second.GetEntrySize(); evicted_so_far_size += to_evict_size; - entries_set_.erase(found_meta); ++it; } - cache_size_ -= evicted_so_far_size; // Take out the rest of hashes from the eviction list. - entry_hashes->erase(it, entry_hashes->end()); - UMA_HISTOGRAM_COUNTS("SimpleCache.Eviction.EntryCount", entry_hashes->size()); - UMA_HISTOGRAM_TIMES("SimpleCache.Eviction.TimeToSelectEntries", - base::TimeTicks::Now() - eviction_start_time_); - UMA_HISTOGRAM_MEMORY_KB("SimpleCache.Eviction.SizeOfEvicted2", - evicted_so_far_size / kBytesInKb); - - index_file_->DoomEntrySet( - entry_hashes.Pass(), - base::Bind(&SimpleIndex::EvictionDone, AsWeakPtr())); + entry_hashes.erase(it, entry_hashes.end()); + SIMPLE_CACHE_UMA(COUNTS, + "Eviction.EntryCount", cache_type_, entry_hashes.size()); + SIMPLE_CACHE_UMA(TIMES, + "Eviction.TimeToSelectEntries", cache_type_, + base::TimeTicks::Now() - eviction_start_time_); + SIMPLE_CACHE_UMA(MEMORY_KB, + "Eviction.SizeOfEvicted2", cache_type_, + evicted_so_far_size / kBytesInKb); + + delegate_->DoomEntries(&entry_hashes, base::Bind(&SimpleIndex::EvictionDone, + AsWeakPtr())); } -bool SimpleIndex::UpdateEntrySize(const std::string& key, uint64 entry_size) { +bool SimpleIndex::UpdateEntrySize(uint64 entry_hash, int entry_size) { DCHECK(io_thread_checker_.CalledOnValidThread()); - EntrySet::iterator it = entries_set_.find(simple_util::GetEntryHashKey(key)); + EntrySet::iterator it = entries_set_.find(entry_hash); if (it == entries_set_.end()) return false; @@ -309,20 +367,22 @@ void SimpleIndex::EvictionDone(int result) { // Ignore the result of eviction. We did our best. eviction_in_progress_ = false; - UMA_HISTOGRAM_BOOLEAN("SimpleCache.Eviction.Result", result == net::OK); - UMA_HISTOGRAM_TIMES("SimpleCache.Eviction.TimeToDone", - base::TimeTicks::Now() - eviction_start_time_); - UMA_HISTOGRAM_MEMORY_KB("SimpleCache.Eviction.SizeWhenDone2", - cache_size_ / kBytesInKb); + SIMPLE_CACHE_UMA(BOOLEAN, "Eviction.Result", cache_type_, result == net::OK); + SIMPLE_CACHE_UMA(TIMES, + "Eviction.TimeToDone", cache_type_, + base::TimeTicks::Now() - eviction_start_time_); + SIMPLE_CACHE_UMA(MEMORY_KB, + "Eviction.SizeWhenDone2", cache_type_, + cache_size_ / kBytesInKb); } // static void SimpleIndex::InsertInEntrySet( - uint64 hash_key, + uint64 entry_hash, const disk_cache::EntryMetadata& entry_metadata, EntrySet* entry_set) { DCHECK(entry_set); - entry_set->insert(std::make_pair(hash_key, entry_metadata)); + entry_set->insert(std::make_pair(entry_hash, entry_metadata)); } void SimpleIndex::PostponeWritingToDisk() { @@ -336,10 +396,10 @@ void SimpleIndex::PostponeWritingToDisk() { } void SimpleIndex::UpdateEntryIteratorSize(EntrySet::iterator* it, - uint64 entry_size) { + int entry_size) { // Update the total cache size with the new entry size. DCHECK(io_thread_checker_.CalledOnValidThread()); - DCHECK_GE(cache_size_, (*it)->second.GetEntrySize()); + DCHECK_GE(cache_size_, implicit_cast<uint64>((*it)->second.GetEntrySize())); cache_size_ -= (*it)->second.GetEntrySize(); cache_size_ += entry_size; (*it)->second.SetEntrySize(entry_size); @@ -350,14 +410,13 @@ void SimpleIndex::MergeInitializingSet( DCHECK(io_thread_checker_.CalledOnValidThread()); DCHECK(load_result->did_load); - SimpleIndex::EntrySet* index_file_entries = &load_result->entries; - // First, remove the entries that are in the |removed_entries_| from both - // sets. - for (base::hash_set<uint64>::const_iterator it = - removed_entries_.begin(); it != removed_entries_.end(); ++it) { - entries_set_.erase(*it); + EntrySet* index_file_entries = &load_result->entries; + + for (base::hash_set<uint64>::const_iterator it = removed_entries_.begin(); + it != removed_entries_.end(); ++it) { index_file_entries->erase(*it); } + removed_entries_.clear(); for (EntrySet::const_iterator it = entries_set_.begin(); it != entries_set_.end(); ++it) { @@ -378,15 +437,15 @@ void SimpleIndex::MergeInitializingSet( entries_set_.swap(*index_file_entries); cache_size_ = merged_cache_size; initialized_ = true; - removed_entries_.clear(); // The actual IO is asynchronous, so calling WriteToDisk() shouldn't slow the // merge down much. if (load_result->flush_required) WriteToDisk(); - UMA_HISTOGRAM_CUSTOM_COUNTS("SimpleCache.IndexInitializationWaiters", - to_run_when_initialized_.size(), 0, 100, 20); + SIMPLE_CACHE_UMA(CUSTOM_COUNTS, + "IndexInitializationWaiters", cache_type_, + to_run_when_initialized_.size(), 0, 100, 20); // Run all callbacks waiting for the index to come up. for (CallbackList::iterator it = to_run_when_initialized_.begin(), end = to_run_when_initialized_.end(); it != end; ++it) { @@ -415,16 +474,19 @@ void SimpleIndex::WriteToDisk() { DCHECK(io_thread_checker_.CalledOnValidThread()); if (!initialized_) return; - UMA_HISTOGRAM_CUSTOM_COUNTS("SimpleCache.IndexNumEntriesOnWrite", - entries_set_.size(), 0, 100000, 50); + SIMPLE_CACHE_UMA(CUSTOM_COUNTS, + "IndexNumEntriesOnWrite", cache_type_, + entries_set_.size(), 0, 100000, 50); const base::TimeTicks start = base::TimeTicks::Now(); if (!last_write_to_disk_.is_null()) { if (app_on_background_) { - UMA_HISTOGRAM_MEDIUM_TIMES("SimpleCache.IndexWriteInterval.Background", - start - last_write_to_disk_); + SIMPLE_CACHE_UMA(MEDIUM_TIMES, + "IndexWriteInterval.Background", cache_type_, + start - last_write_to_disk_); } else { - UMA_HISTOGRAM_MEDIUM_TIMES("SimpleCache.IndexWriteInterval.Foreground", - start - last_write_to_disk_); + SIMPLE_CACHE_UMA(MEDIUM_TIMES, + "IndexWriteInterval.Foreground", cache_type_, + start - last_write_to_disk_); } } last_write_to_disk_ = start; @@ -433,29 +495,4 @@ void SimpleIndex::WriteToDisk() { start, app_on_background_); } -scoped_ptr<SimpleIndex::HashList> SimpleIndex::ExtractEntriesBetween( - const base::Time initial_time, const base::Time end_time, - bool delete_entries) { - DCHECK_EQ(true, initialized_); - const base::Time extended_end_time = - end_time.is_null() ? base::Time::Max() : end_time; - DCHECK(extended_end_time >= initial_time); - scoped_ptr<HashList> ret_hashes(new HashList()); - for (EntrySet::iterator it = entries_set_.begin(), end = entries_set_.end(); - it != end;) { - EntryMetadata& metadata = it->second; - base::Time entry_time = metadata.GetLastUsedTime(); - if (initial_time <= entry_time && entry_time < extended_end_time) { - ret_hashes->push_back(it->first); - if (delete_entries) { - cache_size_ -= metadata.GetEntrySize(); - entries_set_.erase(it++); - continue; - } - } - ++it; - } - return ret_hashes.Pass(); -} - } // namespace disk_cache diff --git a/chromium/net/disk_cache/simple/simple_index.h b/chromium/net/disk_cache/simple/simple_index.h index 788ffb2cfe8..6c0d81c99a6 100644 --- a/chromium/net/disk_cache/simple/simple_index.h +++ b/chromium/net/disk_cache/simple/simple_index.h @@ -6,7 +6,6 @@ #define NET_DISK_CACHE_SIMPLE_SIMPLE_INDEX_H_ #include <list> -#include <string> #include <vector> #include "base/basictypes.h" @@ -21,6 +20,7 @@ #include "base/threading/thread_checker.h" #include "base/time/time.h" #include "base/timer/timer.h" +#include "net/base/cache_type.h" #include "net/base/completion_callback.h" #include "net/base/net_export.h" @@ -33,38 +33,43 @@ class PickleIterator; namespace disk_cache { +class SimpleIndexDelegate; class SimpleIndexFile; struct SimpleIndexLoadResult; class NET_EXPORT_PRIVATE EntryMetadata { public: EntryMetadata(); - EntryMetadata(base::Time last_used_time, uint64 entry_size); + EntryMetadata(base::Time last_used_time, int entry_size); base::Time GetLastUsedTime() const; void SetLastUsedTime(const base::Time& last_used_time); - uint64 GetEntrySize() const { return entry_size_; } - void SetEntrySize(uint64 entry_size) { entry_size_ = entry_size; } + int GetEntrySize() const { return entry_size_; } + void SetEntrySize(int entry_size) { entry_size_ = entry_size; } // Serialize the data into the provided pickle. void Serialize(Pickle* pickle) const; bool Deserialize(PickleIterator* it); + static base::TimeDelta GetLowerEpsilonForTimeComparisons() { + return base::TimeDelta::FromSeconds(1); + } + static base::TimeDelta GetUpperEpsilonForTimeComparisons() { + return base::TimeDelta(); + } + private: friend class SimpleIndexFileTest; // When adding new members here, you should update the Serialize() and // Deserialize() methods. - // This is the serialized format from Time::ToInternalValue(). - // If you want to make calculations/comparisons, you should use the - // base::Time() class. Use the GetLastUsedTime() method above. - // TODO(felipeg): Use Time() here. - int64 last_used_time_; + uint32 last_used_time_seconds_since_epoch_; - uint64 entry_size_; // Storage size in bytes. + int32 entry_size_; // Storage size in bytes. }; +COMPILE_ASSERT(sizeof(EntryMetadata) == 8, metadata_size); // This class is not Thread-safe. class NET_EXPORT_PRIVATE SimpleIndex @@ -73,7 +78,8 @@ class NET_EXPORT_PRIVATE SimpleIndex typedef std::vector<uint64> HashList; SimpleIndex(base::SingleThreadTaskRunner* io_thread, - const base::FilePath& cache_directory, + SimpleIndexDelegate* delegate, + net::CacheType cache_type, scoped_ptr<SimpleIndexFile> simple_index_file); virtual ~SimpleIndex(); @@ -83,38 +89,38 @@ class NET_EXPORT_PRIVATE SimpleIndex bool SetMaxSize(int max_bytes); int max_size() const { return max_size_; } - void Insert(const std::string& key); - void Remove(const std::string& key); + void Insert(uint64 entry_hash); + void Remove(uint64 entry_hash); // Check whether the index has the entry given the hash of its key. - bool Has(uint64 hash) const; + bool Has(uint64 entry_hash) const; // Update the last used time of the entry with the given key and return true // iff the entry exist in the index. - bool UseIfExists(const std::string& key); + bool UseIfExists(uint64 entry_hash); void WriteToDisk(); // Update the size (in bytes) of an entry, in the metadata stored in the // index. This should be the total disk-file size including all streams of the // entry. - bool UpdateEntrySize(const std::string& key, uint64 entry_size); + bool UpdateEntrySize(uint64 entry_hash, int entry_size); typedef base::hash_map<uint64, EntryMetadata> EntrySet; - static void InsertInEntrySet(uint64 hash_key, + static void InsertInEntrySet(uint64 entry_hash, const EntryMetadata& entry_metadata, EntrySet* entry_set); // Executes the |callback| when the index is ready. Allows multiple callbacks. int ExecuteWhenReady(const net::CompletionCallback& callback); - // Takes out entries from the index that have last accessed time matching the + // Returns entries from the index that have last accessed time matching the // range between |initial_time| and |end_time| where open intervals are // possible according to the definition given in |DoomEntriesBetween()| in the - // disk cache backend interface. Returns the set of hashes taken out. - scoped_ptr<HashList> RemoveEntriesBetween(const base::Time initial_time, - const base::Time end_time); + // disk cache backend interface. + scoped_ptr<HashList> GetEntriesBetween(const base::Time initial_time, + const base::Time end_time); // Returns the list of all entries key hash. scoped_ptr<HashList> GetAllHashes(); @@ -137,7 +143,7 @@ class NET_EXPORT_PRIVATE SimpleIndex void PostponeWritingToDisk(); - void UpdateEntryIteratorSize(EntrySet::iterator* it, uint64 entry_size); + void UpdateEntryIteratorSize(EntrySet::iterator* it, int entry_size); // Must run on IO Thread. void MergeInitializingSet(scoped_ptr<SimpleIndexLoadResult> load_result); @@ -148,12 +154,12 @@ class NET_EXPORT_PRIVATE SimpleIndex scoped_ptr<base::android::ActivityStatus::Listener> activity_status_listener_; #endif - scoped_ptr<HashList> ExtractEntriesBetween(const base::Time initial_time, - const base::Time end_time, - bool delete_entries); + // The owner of |this| must ensure the |delegate_| outlives |this|. + SimpleIndexDelegate* delegate_; EntrySet entries_set_; + const net::CacheType cache_type_; uint64 cache_size_; // Total cache storage size in bytes. uint64 max_size_; uint64 high_watermark_; @@ -161,12 +167,11 @@ class NET_EXPORT_PRIVATE SimpleIndex bool eviction_in_progress_; base::TimeTicks eviction_start_time_; - // This stores all the hash_key of entries that are removed during + // This stores all the entry_hash of entries that are removed during // initialization. base::hash_set<uint64> removed_entries_; bool initialized_; - const base::FilePath& cache_directory_; scoped_ptr<SimpleIndexFile> index_file_; scoped_refptr<base::SingleThreadTaskRunner> io_thread_; diff --git a/chromium/net/disk_cache/simple/simple_index_delegate.h b/chromium/net/disk_cache/simple/simple_index_delegate.h new file mode 100644 index 00000000000..e942484164e --- /dev/null +++ b/chromium/net/disk_cache/simple/simple_index_delegate.h @@ -0,0 +1,28 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_DISK_CACHE_SIMPLE_SIMPLE_INDEX_DELEGATE_H_ +#define NET_DISK_CACHE_SIMPLE_SIMPLE_INDEX_DELEGATE_H_ + +#include <vector> + +#include "net/base/completion_callback.h" +#include "net/base/net_export.h" + +namespace disk_cache { + +class NET_EXPORT_PRIVATE SimpleIndexDelegate { + public: + virtual ~SimpleIndexDelegate() {} + + // Dooms all entries in |entries|, calling |callback| with the result + // asynchronously. |entries| is mutated in an undefined way by this call, + // for efficiency. + virtual void DoomEntries(std::vector<uint64>* entry_hashes, + const net::CompletionCallback& callback) = 0; +}; + +} // namespace disk_cache + +#endif // NET_DISK_CACHE_SIMPLE_SIMPLE_INDEX_DELEGATE_H_ diff --git a/chromium/net/disk_cache/simple/simple_index_file.cc b/chromium/net/disk_cache/simple/simple_index_file.cc index 0136be1c5a0..c350cddb39f 100644 --- a/chromium/net/disk_cache/simple/simple_index_file.cc +++ b/chromium/net/disk_cache/simple/simple_index_file.cc @@ -10,12 +10,13 @@ #include "base/files/memory_mapped_file.h" #include "base/hash.h" #include "base/logging.h" -#include "base/metrics/histogram.h" #include "base/pickle.h" #include "base/single_thread_task_runner.h" #include "base/task_runner_util.h" #include "base/threading/thread_restrictions.h" +#include "net/disk_cache/simple/simple_backend_version.h" #include "net/disk_cache/simple/simple_entry_format.h" +#include "net/disk_cache/simple/simple_histogram_macros.h" #include "net/disk_cache/simple/simple_index.h" #include "net/disk_cache/simple/simple_synchronous_entry.h" #include "net/disk_cache/simple/simple_util.h" @@ -38,38 +39,43 @@ uint32 CalculatePickleCRC(const Pickle& pickle) { pickle.payload_size()); } -void DoomEntrySetReply(const net::CompletionCallback& reply_callback, - int result) { - reply_callback.Run(result); +// Used in histograms. Please only add new values at the end. +enum IndexFileState { + INDEX_STATE_CORRUPT = 0, + INDEX_STATE_STALE = 1, + INDEX_STATE_FRESH = 2, + INDEX_STATE_FRESH_CONCURRENT_UPDATES = 3, + INDEX_STATE_MAX = 4, +}; + +void UmaRecordIndexFileState(IndexFileState state, net::CacheType cache_type) { + SIMPLE_CACHE_UMA(ENUMERATION, + "IndexFileStateOnLoad", cache_type, state, INDEX_STATE_MAX); } -void WriteToDiskInternal(const base::FilePath& index_filename, - const base::FilePath& temp_index_filename, - scoped_ptr<Pickle> pickle, - const base::TimeTicks& start_time, - bool app_on_background) { +// Used in histograms. Please only add new values at the end. +enum IndexInitMethod { + INITIALIZE_METHOD_RECOVERED = 0, + INITIALIZE_METHOD_LOADED = 1, + INITIALIZE_METHOD_NEWCACHE = 2, + INITIALIZE_METHOD_MAX = 3, +}; + +void UmaRecordIndexInitMethod(IndexInitMethod method, + net::CacheType cache_type) { + SIMPLE_CACHE_UMA(ENUMERATION, + "IndexInitializeMethod", cache_type, + method, INITIALIZE_METHOD_MAX); +} + +bool WritePickleFile(Pickle* pickle, const base::FilePath& file_name) { int bytes_written = file_util::WriteFile( - temp_index_filename, - reinterpret_cast<const char*>(pickle->data()), - pickle->size()); - DCHECK_EQ(bytes_written, implicit_cast<int>(pickle->size())); - if (bytes_written != static_cast<int>(pickle->size())) { - // TODO(felipeg): Add better error handling. - LOG(ERROR) << "Could not write Simple Cache index to temporary file: " - << temp_index_filename.value(); - base::DeleteFile(temp_index_filename, /* recursive = */ false); - } else { - // Swap temp and index_file. - bool result = base::ReplaceFile(temp_index_filename, index_filename, NULL); - DCHECK(result); - } - if (app_on_background) { - UMA_HISTOGRAM_TIMES("SimpleCache.IndexWriteToDiskTime.Background", - (base::TimeTicks::Now() - start_time)); - } else { - UMA_HISTOGRAM_TIMES("SimpleCache.IndexWriteToDiskTime.Foreground", - (base::TimeTicks::Now() - start_time)); + file_name, static_cast<const char*>(pickle->data()), pickle->size()); + if (bytes_written != implicit_cast<int>(pickle->size())) { + base::DeleteFile(file_name, /* recursive = */ false); + return false; } + return true; } // Called for each cache directory traversal iteration. @@ -114,7 +120,7 @@ void ProcessEntryFile(SimpleIndex::EntrySet* entries, EntryMetadata(last_used_time, file_size), entries); } else { - // Summing up the total size of the entry through all the *_[0-2] files + // Summing up the total size of the entry through all the *_[0-1] files it->second.SetEntrySize(it->second.GetEntrySize() + file_size); } } @@ -134,18 +140,25 @@ void SimpleIndexLoadResult::Reset() { entries.clear(); } -SimpleIndexFile::IndexMetadata::IndexMetadata() : - magic_number_(kSimpleIndexMagicNumber), - version_(kSimpleVersion), - number_of_entries_(0), - cache_size_(0) {} +// static +const char SimpleIndexFile::kIndexFileName[] = "the-real-index"; +// static +const char SimpleIndexFile::kIndexDirectory[] = "index-dir"; +// static +const char SimpleIndexFile::kTempIndexFileName[] = "temp-index"; + +SimpleIndexFile::IndexMetadata::IndexMetadata() + : magic_number_(kSimpleIndexMagicNumber), + version_(kSimpleVersion), + number_of_entries_(0), + cache_size_(0) {} SimpleIndexFile::IndexMetadata::IndexMetadata( - uint64 number_of_entries, uint64 cache_size) : - magic_number_(kSimpleIndexMagicNumber), - version_(kSimpleVersion), - number_of_entries_(number_of_entries), - cache_size_(cache_size) {} + uint64 number_of_entries, uint64 cache_size) + : magic_number_(kSimpleIndexMagicNumber), + version_(kSimpleVersion), + number_of_entries_(number_of_entries), + cache_size_(cache_size) {} void SimpleIndexFile::IndexMetadata::Serialize(Pickle* pickle) const { DCHECK(pickle); @@ -155,6 +168,16 @@ void SimpleIndexFile::IndexMetadata::Serialize(Pickle* pickle) const { pickle->WriteUInt64(cache_size_); } +// static +bool SimpleIndexFile::SerializeFinalData(base::Time cache_modified, + Pickle* pickle) { + if (!pickle->WriteInt64(cache_modified.ToInternalValue())) + return false; + SimpleIndexFile::PickleHeader* header_p = pickle->headerT<PickleHeader>(); + header_p->crc = CalculatePickleCRC(*pickle); + return true; +} + bool SimpleIndexFile::IndexMetadata::Deserialize(PickleIterator* it) { DCHECK(it); return it->ReadUInt64(&magic_number_) && @@ -163,21 +186,70 @@ bool SimpleIndexFile::IndexMetadata::Deserialize(PickleIterator* it) { it->ReadUInt64(&cache_size_); } +void SimpleIndexFile::SyncWriteToDisk(net::CacheType cache_type, + const base::FilePath& cache_directory, + const base::FilePath& index_filename, + const base::FilePath& temp_index_filename, + scoped_ptr<Pickle> pickle, + const base::TimeTicks& start_time, + bool app_on_background) { + // There is a chance that the index containing all the necessary data about + // newly created entries will appear to be stale. This can happen if on-disk + // part of a Create operation does not fit into the time budget for the index + // flush delay. This simple approach will be reconsidered if it does not allow + // for maintaining freshness. + base::PlatformFileInfo cache_dir_info; + base::Time cache_dir_mtime; + if (!simple_util::GetMTime(cache_directory, &cache_dir_mtime)) { + LOG(ERROR) << "Could obtain information about cache age"; + return; + } + SerializeFinalData(cache_dir_mtime, pickle.get()); + if (!WritePickleFile(pickle.get(), temp_index_filename)) { + if (!file_util::CreateDirectory(temp_index_filename.DirName())) { + LOG(ERROR) << "Could not create a directory to hold the index file"; + return; + } + if (!WritePickleFile(pickle.get(), temp_index_filename)) { + LOG(ERROR) << "Failed to write the temporary index file"; + return; + } + } + + // Atomically rename the temporary index file to become the real one. + bool result = base::ReplaceFile(temp_index_filename, index_filename, NULL); + DCHECK(result); + + if (app_on_background) { + SIMPLE_CACHE_UMA(TIMES, + "IndexWriteToDiskTime.Background", cache_type, + (base::TimeTicks::Now() - start_time)); + } else { + SIMPLE_CACHE_UMA(TIMES, + "IndexWriteToDiskTime.Foreground", cache_type, + (base::TimeTicks::Now() - start_time)); + } +} + bool SimpleIndexFile::IndexMetadata::CheckIndexMetadata() { return number_of_entries_ <= kMaxEntiresInIndex && - magic_number_ == disk_cache::kSimpleIndexMagicNumber && - version_ == disk_cache::kSimpleVersion; + magic_number_ == kSimpleIndexMagicNumber && + version_ == kSimpleVersion; } SimpleIndexFile::SimpleIndexFile( base::SingleThreadTaskRunner* cache_thread, base::TaskRunner* worker_pool, + net::CacheType cache_type, const base::FilePath& cache_directory) : cache_thread_(cache_thread), worker_pool_(worker_pool), + cache_type_(cache_type), cache_directory_(cache_directory), - index_file_(cache_directory_.AppendASCII(kIndexFileName)), - temp_index_file_(cache_directory_.AppendASCII(kTempIndexFileName)) { + index_file_(cache_directory_.AppendASCII(kIndexDirectory) + .AppendASCII(kIndexFileName)), + temp_index_file_(cache_directory_.AppendASCII(kIndexDirectory) + .AppendASCII(kTempIndexFileName)) { } SimpleIndexFile::~SimpleIndexFile() {} @@ -186,6 +258,7 @@ void SimpleIndexFile::LoadIndexEntries(base::Time cache_last_modified, const base::Closure& callback, SimpleIndexLoadResult* out_result) { base::Closure task = base::Bind(&SimpleIndexFile::SyncLoadIndexEntries, + cache_type_, cache_last_modified, cache_directory_, index_file_, out_result); worker_pool_->PostTaskAndReply(FROM_HERE, task, callback); @@ -198,7 +271,9 @@ void SimpleIndexFile::WriteToDisk(const SimpleIndex::EntrySet& entry_set, IndexMetadata index_metadata(entry_set.size(), cache_size); scoped_ptr<Pickle> pickle = Serialize(index_metadata, entry_set); cache_thread_->PostTask(FROM_HERE, base::Bind( - &WriteToDiskInternal, + &SimpleIndexFile::SyncWriteToDisk, + cache_type_, + cache_directory_, index_file_, temp_index_file_, base::Passed(&pickle), @@ -206,97 +281,58 @@ void SimpleIndexFile::WriteToDisk(const SimpleIndex::EntrySet& entry_set, app_on_background)); } -void SimpleIndexFile::DoomEntrySet( - scoped_ptr<std::vector<uint64> > entry_hashes, - const net::CompletionCallback& reply_callback) { - PostTaskAndReplyWithResult( - worker_pool_, - FROM_HERE, - base::Bind(&SimpleSynchronousEntry::DoomEntrySet, - base::Passed(entry_hashes.Pass()), cache_directory_), - base::Bind(&DoomEntrySetReply, reply_callback)); -} - // static void SimpleIndexFile::SyncLoadIndexEntries( + net::CacheType cache_type, base::Time cache_last_modified, const base::FilePath& cache_directory, const base::FilePath& index_file_path, SimpleIndexLoadResult* out_result) { - // TODO(felipeg): probably could load a stale index and use it for something. - const SimpleIndex::EntrySet& entries = out_result->entries; - - const bool index_file_exists = base::PathExists(index_file_path); - - // Used in histograms. Please only add new values at the end. - enum { - INDEX_STATE_CORRUPT = 0, - INDEX_STATE_STALE = 1, - INDEX_STATE_FRESH = 2, - INDEX_STATE_FRESH_CONCURRENT_UPDATES = 3, - INDEX_STATE_MAX = 4, - } index_file_state; - - // Only load if the index is not stale. - if (IsIndexFileStale(cache_last_modified, index_file_path)) { - index_file_state = INDEX_STATE_STALE; - } else { - index_file_state = INDEX_STATE_FRESH; - base::Time latest_dir_mtime; - if (simple_util::GetMTime(cache_directory, &latest_dir_mtime) && - IsIndexFileStale(latest_dir_mtime, index_file_path)) { - // A file operation has updated the directory since we last looked at it - // during backend initialization. - index_file_state = INDEX_STATE_FRESH_CONCURRENT_UPDATES; - } - - const base::TimeTicks start = base::TimeTicks::Now(); - SyncLoadFromDisk(index_file_path, out_result); - UMA_HISTOGRAM_TIMES("SimpleCache.IndexLoadTime", - base::TimeTicks::Now() - start); - UMA_HISTOGRAM_COUNTS("SimpleCache.IndexEntriesLoaded", - out_result->did_load ? entries.size() : 0); - if (!out_result->did_load) - index_file_state = INDEX_STATE_CORRUPT; - } - UMA_HISTOGRAM_ENUMERATION("SimpleCache.IndexFileStateOnLoad", - index_file_state, - INDEX_STATE_MAX); + // Load the index and find its age. + base::Time last_cache_seen_by_index; + SyncLoadFromDisk(index_file_path, &last_cache_seen_by_index, out_result); + // Consider the index loaded if it is fresh. + const bool index_file_existed = base::PathExists(index_file_path); if (!out_result->did_load) { - const base::TimeTicks start = base::TimeTicks::Now(); - SyncRestoreFromDisk(cache_directory, index_file_path, out_result); - UMA_HISTOGRAM_MEDIUM_TIMES("SimpleCache.IndexRestoreTime", - base::TimeTicks::Now() - start); - UMA_HISTOGRAM_COUNTS("SimpleCache.IndexEntriesRestored", - entries.size()); + if (index_file_existed) + UmaRecordIndexFileState(INDEX_STATE_CORRUPT, cache_type); + } else { + if (cache_last_modified <= last_cache_seen_by_index) { + base::Time latest_dir_mtime; + simple_util::GetMTime(cache_directory, &latest_dir_mtime); + if (LegacyIsIndexFileStale(latest_dir_mtime, index_file_path)) { + UmaRecordIndexFileState(INDEX_STATE_FRESH_CONCURRENT_UPDATES, + cache_type); + } else { + UmaRecordIndexFileState(INDEX_STATE_FRESH, cache_type); + } + UmaRecordIndexInitMethod(INITIALIZE_METHOD_LOADED, cache_type); + return; + } + UmaRecordIndexFileState(INDEX_STATE_STALE, cache_type); } - // Used in histograms. Please only add new values at the end. - enum { - INITIALIZE_METHOD_RECOVERED = 0, - INITIALIZE_METHOD_LOADED = 1, - INITIALIZE_METHOD_NEWCACHE = 2, - INITIALIZE_METHOD_MAX = 3, - }; - int initialize_method; - if (index_file_exists) { - if (out_result->flush_required) - initialize_method = INITIALIZE_METHOD_RECOVERED; - else - initialize_method = INITIALIZE_METHOD_LOADED; + // Reconstruct the index by scanning the disk for entries. + const base::TimeTicks start = base::TimeTicks::Now(); + SyncRestoreFromDisk(cache_directory, index_file_path, out_result); + SIMPLE_CACHE_UMA(MEDIUM_TIMES, "IndexRestoreTime", cache_type, + base::TimeTicks::Now() - start); + SIMPLE_CACHE_UMA(COUNTS, "IndexEntriesRestored", cache_type, + out_result->entries.size()); + if (index_file_existed) { + UmaRecordIndexInitMethod(INITIALIZE_METHOD_RECOVERED, cache_type); } else { - UMA_HISTOGRAM_COUNTS("SimpleCache.IndexCreatedEntryCount", - entries.size()); - initialize_method = INITIALIZE_METHOD_NEWCACHE; + UmaRecordIndexInitMethod(INITIALIZE_METHOD_NEWCACHE, cache_type); + SIMPLE_CACHE_UMA(COUNTS, + "IndexCreatedEntryCount", cache_type, + out_result->entries.size()); } - - UMA_HISTOGRAM_ENUMERATION("SimpleCache.IndexInitializeMethod", - initialize_method, INITIALIZE_METHOD_MAX); } // static void SimpleIndexFile::SyncLoadFromDisk(const base::FilePath& index_filename, + base::Time* out_last_cache_seen_by_index, SimpleIndexLoadResult* out_result) { out_result->Reset(); @@ -309,7 +345,9 @@ void SimpleIndexFile::SyncLoadFromDisk(const base::FilePath& index_filename, SimpleIndexFile::Deserialize( reinterpret_cast<const char*>(index_file_map.data()), - index_file_map.length(), out_result); + index_file_map.length(), + out_last_cache_seen_by_index, + out_result); if (!out_result->did_load) base::DeleteFile(index_filename, false); @@ -327,14 +365,12 @@ scoped_ptr<Pickle> SimpleIndexFile::Serialize( pickle->WriteUInt64(it->first); it->second.Serialize(pickle.get()); } - SimpleIndexFile::PickleHeader* header_p = - pickle->headerT<SimpleIndexFile::PickleHeader>(); - header_p->crc = CalculatePickleCRC(*pickle); return pickle.Pass(); } // static void SimpleIndexFile::Deserialize(const char* data, int data_len, + base::Time* out_cache_last_modified, SimpleIndexLoadResult* out_result) { DCHECK(data); @@ -348,7 +384,6 @@ void SimpleIndexFile::Deserialize(const char* data, int data_len, } PickleIterator pickle_it(pickle); - SimpleIndexFile::PickleHeader* header_p = pickle.headerT<SimpleIndexFile::PickleHeader>(); const uint32 crc_read = header_p->crc; @@ -386,6 +421,14 @@ void SimpleIndexFile::Deserialize(const char* data, int data_len, SimpleIndex::InsertInEntrySet(hash_key, entry_metadata, entries); } + int64 cache_last_modified; + if (!pickle_it.ReadInt64(&cache_last_modified)) { + entries->clear(); + return; + } + DCHECK(out_cache_last_modified); + *out_cache_last_modified = base::Time::FromInternalValue(cache_last_modified); + out_result->did_load = true; } @@ -399,10 +442,6 @@ void SimpleIndexFile::SyncRestoreFromDisk( out_result->Reset(); SimpleIndex::EntrySet* entries = &out_result->entries; - // TODO(felipeg,gavinp): Fix this once we have a one-file per entry format. - COMPILE_ASSERT(kSimpleEntryFileCount == 3, - file_pattern_must_match_file_count); - const bool did_succeed = TraverseCacheDirectory( cache_directory, base::Bind(&ProcessEntryFile, entries)); if (!did_succeed) { @@ -416,8 +455,9 @@ void SimpleIndexFile::SyncRestoreFromDisk( } // static -bool SimpleIndexFile::IsIndexFileStale(base::Time cache_last_modified, - const base::FilePath& index_file_path) { +bool SimpleIndexFile::LegacyIsIndexFileStale( + base::Time cache_last_modified, + const base::FilePath& index_file_path) { base::Time index_mtime; if (!simple_util::GetMTime(index_file_path, &index_mtime)) return true; diff --git a/chromium/net/disk_cache/simple/simple_index_file.h b/chromium/net/disk_cache/simple/simple_index_file.h index e5fc85d69a2..ce19e2bb56f 100644 --- a/chromium/net/disk_cache/simple/simple_index_file.h +++ b/chromium/net/disk_cache/simple/simple_index_file.h @@ -16,6 +16,7 @@ #include "base/memory/scoped_ptr.h" #include "base/pickle.h" #include "base/port.h" +#include "net/base/cache_type.h" #include "net/base/net_export.h" #include "net/disk_cache/simple/simple_index.h" @@ -75,6 +76,7 @@ class NET_EXPORT_PRIVATE SimpleIndexFile { SimpleIndexFile(base::SingleThreadTaskRunner* cache_thread, base::TaskRunner* worker_pool, + net::CacheType cache_type, const base::FilePath& cache_directory); virtual ~SimpleIndexFile(); @@ -89,11 +91,6 @@ class NET_EXPORT_PRIVATE SimpleIndexFile { const base::TimeTicks& start, bool app_on_background); - // Doom the entries specified in |entry_hashes|, calling |reply_callback| - // with the result on the current thread when done. - virtual void DoomEntrySet(scoped_ptr<std::vector<uint64> > entry_hashes, - const base::Callback<void(int)>& reply_callback); - private: friend class WrappedSimpleIndexFile; @@ -105,25 +102,35 @@ class NET_EXPORT_PRIVATE SimpleIndexFile { static const int kExtraSizeForMerge = 512; // Synchronous (IO performing) implementation of LoadIndexEntries. - static void SyncLoadIndexEntries(base::Time cache_last_modified, + static void SyncLoadIndexEntries(net::CacheType cache_type, + base::Time cache_last_modified, const base::FilePath& cache_directory, const base::FilePath& index_file_path, SimpleIndexLoadResult* out_result); - // Load the index file from disk returning an EntrySet. Upon failure, returns - // NULL. + // Load the index file from disk returning an EntrySet. static void SyncLoadFromDisk(const base::FilePath& index_filename, + base::Time* out_last_cache_seen_by_index, SimpleIndexLoadResult* out_result); // Returns a scoped_ptr for a newly allocated Pickle containing the serialized - // data to be written to a file. + // data to be written to a file. Note: the pickle is not in a consistent state + // immediately after calling this menthod, one needs to call + // SerializeFinalData to make it ready to write to a file. static scoped_ptr<Pickle> Serialize( const SimpleIndexFile::IndexMetadata& index_metadata, const SimpleIndex::EntrySet& entries); + // Appends cache modification time data to the serialized format. This is + // performed on a thread accessing the disk. It is not combined with the main + // serialization path to avoid extra thread hops or copying the pickle to the + // worker thread. + static bool SerializeFinalData(base::Time cache_modified, Pickle* pickle); + // Given the contents of an index file |data| of length |data_len|, returns // the corresponding EntrySet. Returns NULL on error. static void Deserialize(const char* data, int data_len, + base::Time* out_cache_last_modified, SimpleIndexLoadResult* out_result); // Implemented either in simple_index_file_posix.cc or @@ -135,6 +142,15 @@ class NET_EXPORT_PRIVATE SimpleIndexFile { const base::FilePath& cache_path, const EntryFileCallback& entry_file_callback); + // Writes the index file to disk atomically. + static void SyncWriteToDisk(net::CacheType cache_type, + const base::FilePath& cache_directory, + const base::FilePath& index_filename, + const base::FilePath& temp_index_filename, + scoped_ptr<Pickle> pickle, + const base::TimeTicks& start_time, + bool app_on_background); + // Scan the index directory for entries, returning an EntrySet of all entries // found. static void SyncRestoreFromDisk(const base::FilePath& cache_directory, @@ -142,9 +158,11 @@ class NET_EXPORT_PRIVATE SimpleIndexFile { SimpleIndexLoadResult* out_result); // Determines if an index file is stale relative to the time of last - // modification of the cache directory. - static bool IsIndexFileStale(base::Time cache_last_modified, - const base::FilePath& index_file_path); + // modification of the cache directory. Obsolete, used only for a histogram to + // compare with the new method. + // TODO(pasko): remove this method after getting enough data. + static bool LegacyIsIndexFileStale(base::Time cache_last_modified, + const base::FilePath& index_file_path); struct PickleHeader : public Pickle::Header { uint32 crc; @@ -152,10 +170,15 @@ class NET_EXPORT_PRIVATE SimpleIndexFile { const scoped_refptr<base::SingleThreadTaskRunner> cache_thread_; const scoped_refptr<base::TaskRunner> worker_pool_; + const net::CacheType cache_type_; const base::FilePath cache_directory_; const base::FilePath index_file_; const base::FilePath temp_index_file_; + static const char kIndexDirectory[]; + static const char kIndexFileName[]; + static const char kTempIndexFileName[]; + DISALLOW_COPY_AND_ASSIGN(SimpleIndexFile); }; diff --git a/chromium/net/disk_cache/simple/simple_index_file_unittest.cc b/chromium/net/disk_cache/simple/simple_index_file_unittest.cc index bf7ee83c30e..0e9c2e8155b 100644 --- a/chromium/net/disk_cache/simple/simple_index_file_unittest.cc +++ b/chromium/net/disk_cache/simple/simple_index_file_unittest.cc @@ -12,6 +12,8 @@ #include "base/run_loop.h" #include "base/strings/stringprintf.h" #include "base/time/time.h" +#include "net/base/cache_type.h" +#include "net/disk_cache/simple/simple_backend_version.h" #include "net/disk_cache/simple/simple_entry_format.h" #include "net/disk_cache/simple/simple_index.h" #include "net/disk_cache/simple/simple_index_file.h" @@ -24,6 +26,11 @@ using disk_cache::SimpleIndex; namespace disk_cache { +// The Simple Cache backend requires a few guarantees from the filesystem like +// atomic renaming of recently open files. Those guarantees are not provided in +// general on Windows. +#if defined(OS_POSIX) + TEST(IndexMetadataTest, Basics) { SimpleIndexFile::IndexMetadata index_metadata; @@ -57,12 +64,14 @@ TEST(IndexMetadataTest, Serialize) { class WrappedSimpleIndexFile : public SimpleIndexFile { public: using SimpleIndexFile::Deserialize; - using SimpleIndexFile::IsIndexFileStale; + using SimpleIndexFile::LegacyIsIndexFileStale; using SimpleIndexFile::Serialize; + using SimpleIndexFile::SerializeFinalData; explicit WrappedSimpleIndexFile(const base::FilePath& index_file_directory) : SimpleIndexFile(base::MessageLoopProxy::current().get(), base::MessageLoopProxy::current().get(), + net::DISK_CACHE, index_file_directory) {} virtual ~WrappedSimpleIndexFile() { } @@ -70,13 +79,19 @@ class WrappedSimpleIndexFile : public SimpleIndexFile { const base::FilePath& GetIndexFilePath() const { return index_file_; } + + bool CreateIndexFileDirectory() const { + return file_util::CreateDirectory(index_file_.DirName()); + } }; class SimpleIndexFileTest : public testing::Test { public: bool CompareTwoEntryMetadata(const EntryMetadata& a, const EntryMetadata& b) { - return a.last_used_time_ == b.last_used_time_ && - a.entry_size_ == b.entry_size_; + return + a.last_used_time_seconds_since_epoch_ == + b.last_used_time_seconds_since_epoch_ && + a.entry_size_ == b.entry_size_; } protected: @@ -108,20 +123,23 @@ TEST_F(SimpleIndexFileTest, Serialize) { 456); for (size_t i = 0; i < kNumHashes; ++i) { uint64 hash = kHashes[i]; - metadata_entries[i] = - EntryMetadata(Time::FromInternalValue(hash), hash); + metadata_entries[i] = EntryMetadata(Time(), hash); SimpleIndex::InsertInEntrySet(hash, metadata_entries[i], &entries); } scoped_ptr<Pickle> pickle = WrappedSimpleIndexFile::Serialize( index_metadata, entries); EXPECT_TRUE(pickle.get() != NULL); - + base::Time now = base::Time::Now(); + EXPECT_TRUE(WrappedSimpleIndexFile::SerializeFinalData(now, pickle.get())); + base::Time when_index_last_saw_cache; SimpleIndexLoadResult deserialize_result; WrappedSimpleIndexFile::Deserialize(static_cast<const char*>(pickle->data()), - pickle->size(), - &deserialize_result); + pickle->size(), + &when_index_last_saw_cache, + &deserialize_result); EXPECT_TRUE(deserialize_result.did_load); + EXPECT_EQ(now, when_index_last_saw_cache); const SimpleIndex::EntrySet& new_entries = deserialize_result.entries; EXPECT_EQ(entries.size(), new_entries.size()); @@ -132,7 +150,7 @@ TEST_F(SimpleIndexFileTest, Serialize) { } } -TEST_F(SimpleIndexFileTest, IsIndexFileStale) { +TEST_F(SimpleIndexFileTest, LegacyIsIndexFileStale) { base::ScopedTempDir cache_dir; ASSERT_TRUE(cache_dir.CreateUniqueTempDir()); base::Time cache_mtime; @@ -140,34 +158,34 @@ TEST_F(SimpleIndexFileTest, IsIndexFileStale) { ASSERT_TRUE(simple_util::GetMTime(cache_path, &cache_mtime)); WrappedSimpleIndexFile simple_index_file(cache_path); + ASSERT_TRUE(simple_index_file.CreateIndexFileDirectory()); const base::FilePath& index_path = simple_index_file.GetIndexFilePath(); - EXPECT_TRUE(WrappedSimpleIndexFile::IsIndexFileStale(cache_mtime, - index_path)); + EXPECT_TRUE( + WrappedSimpleIndexFile::LegacyIsIndexFileStale(cache_mtime, index_path)); const std::string kDummyData = "nothing to be seen here"; EXPECT_EQ(static_cast<int>(kDummyData.size()), file_util::WriteFile(index_path, kDummyData.data(), kDummyData.size())); ASSERT_TRUE(simple_util::GetMTime(cache_path, &cache_mtime)); - EXPECT_FALSE(WrappedSimpleIndexFile::IsIndexFileStale(cache_mtime, - index_path)); + EXPECT_FALSE( + WrappedSimpleIndexFile::LegacyIsIndexFileStale(cache_mtime, index_path)); const base::Time past_time = base::Time::Now() - base::TimeDelta::FromSeconds(10); EXPECT_TRUE(file_util::TouchFile(index_path, past_time, past_time)); EXPECT_TRUE(file_util::TouchFile(cache_path, past_time, past_time)); ASSERT_TRUE(simple_util::GetMTime(cache_path, &cache_mtime)); - EXPECT_FALSE(WrappedSimpleIndexFile::IsIndexFileStale(cache_mtime, - index_path)); - const base::Time even_older = - past_time - base::TimeDelta::FromSeconds(10); + EXPECT_FALSE( + WrappedSimpleIndexFile::LegacyIsIndexFileStale(cache_mtime, index_path)); + const base::Time even_older = past_time - base::TimeDelta::FromSeconds(10); EXPECT_TRUE(file_util::TouchFile(index_path, even_older, even_older)); - EXPECT_TRUE(WrappedSimpleIndexFile::IsIndexFileStale(cache_mtime, - index_path)); - + EXPECT_TRUE( + WrappedSimpleIndexFile::LegacyIsIndexFileStale(cache_mtime, index_path)); } -TEST_F(SimpleIndexFileTest, WriteThenLoadIndex) { +// This test is flaky, see http://crbug.com/255775. +TEST_F(SimpleIndexFileTest, DISABLED_WriteThenLoadIndex) { base::ScopedTempDir cache_dir; ASSERT_TRUE(cache_dir.CreateUniqueTempDir()); @@ -177,8 +195,7 @@ TEST_F(SimpleIndexFileTest, WriteThenLoadIndex) { EntryMetadata metadata_entries[kNumHashes]; for (size_t i = 0; i < kNumHashes; ++i) { uint64 hash = kHashes[i]; - metadata_entries[i] = - EntryMetadata(Time::FromInternalValue(hash), hash); + metadata_entries[i] = EntryMetadata(Time(), hash); SimpleIndex::InsertInEntrySet(hash, metadata_entries[i], &entries); } @@ -216,17 +233,17 @@ TEST_F(SimpleIndexFileTest, LoadCorruptIndex) { ASSERT_TRUE(cache_dir.CreateUniqueTempDir()); WrappedSimpleIndexFile simple_index_file(cache_dir.path()); + ASSERT_TRUE(simple_index_file.CreateIndexFileDirectory()); const base::FilePath& index_path = simple_index_file.GetIndexFilePath(); const std::string kDummyData = "nothing to be seen here"; - EXPECT_EQ(static_cast<int>(kDummyData.size()), - file_util::WriteFile(index_path, - kDummyData.data(), - kDummyData.size())); + EXPECT_EQ( + implicit_cast<int>(kDummyData.size()), + file_util::WriteFile(index_path, kDummyData.data(), kDummyData.size())); base::Time fake_cache_mtime; ASSERT_TRUE(simple_util::GetMTime(simple_index_file.GetIndexFilePath(), &fake_cache_mtime)); - EXPECT_FALSE(WrappedSimpleIndexFile::IsIndexFileStale(fake_cache_mtime, - index_path)); + EXPECT_FALSE(WrappedSimpleIndexFile::LegacyIsIndexFileStale(fake_cache_mtime, + index_path)); SimpleIndexLoadResult load_index_result; simple_index_file.LoadIndexEntries(fake_cache_mtime, @@ -240,4 +257,6 @@ TEST_F(SimpleIndexFileTest, LoadCorruptIndex) { EXPECT_TRUE(load_index_result.flush_required); } +#endif // defined(OS_POSIX) + } // namespace disk_cache diff --git a/chromium/net/disk_cache/simple/simple_index_unittest.cc b/chromium/net/disk_cache/simple/simple_index_unittest.cc index 0c845b21318..47ae24f62fe 100644 --- a/chromium/net/disk_cache/simple/simple_index_unittest.cc +++ b/chromium/net/disk_cache/simple/simple_index_unittest.cc @@ -2,6 +2,9 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. +#include <algorithm> +#include <functional> + #include "base/files/scoped_temp_dir.h" #include "base/hash.h" #include "base/logging.h" @@ -12,24 +15,23 @@ #include "base/task_runner.h" #include "base/threading/platform_thread.h" #include "base/time/time.h" +#include "net/base/cache_type.h" #include "net/disk_cache/simple/simple_index.h" +#include "net/disk_cache/simple/simple_index_delegate.h" #include "net/disk_cache/simple/simple_index_file.h" +#include "net/disk_cache/simple/simple_test_util.h" #include "net/disk_cache/simple/simple_util.h" #include "testing/gtest/include/gtest/gtest.h" +namespace disk_cache { namespace { -const int64 kTestLastUsedTimeInternal = 12345; const base::Time kTestLastUsedTime = - base::Time::FromInternalValue(kTestLastUsedTimeInternal); -const uint64 kTestEntrySize = 789; -const uint64 kKey1Hash = disk_cache::simple_util::GetEntryHashKey("key1"); -const uint64 kKey2Hash = disk_cache::simple_util::GetEntryHashKey("key2"); -const uint64 kKey3Hash = disk_cache::simple_util::GetEntryHashKey("key3"); + base::Time::UnixEpoch() + base::TimeDelta::FromDays(20); +const int kTestEntrySize = 789; } // namespace -namespace disk_cache { class EntryMetadataTest : public testing::Test { public: @@ -38,7 +40,10 @@ class EntryMetadataTest : public testing::Test { } void CheckEntryMetadataValues(const EntryMetadata& entry_metadata) { - EXPECT_EQ(kTestLastUsedTime, entry_metadata.GetLastUsedTime()); + EXPECT_LT(kTestLastUsedTime - base::TimeDelta::FromSeconds(2), + entry_metadata.GetLastUsedTime()); + EXPECT_GT(kTestLastUsedTime + base::TimeDelta::FromSeconds(2), + entry_metadata.GetLastUsedTime()); EXPECT_EQ(kTestEntrySize, entry_metadata.GetEntrySize()); } }; @@ -47,10 +52,9 @@ class MockSimpleIndexFile : public SimpleIndexFile, public base::SupportsWeakPtr<MockSimpleIndexFile> { public: MockSimpleIndexFile() - : SimpleIndexFile(NULL, NULL, base::FilePath()), + : SimpleIndexFile(NULL, NULL, net::DISK_CACHE, base::FilePath()), load_result_(NULL), load_index_entries_calls_(0), - doom_entry_set_calls_(0), disk_writes_(0) {} virtual void LoadIndexEntries( @@ -70,14 +74,6 @@ class MockSimpleIndexFile : public SimpleIndexFile, disk_write_entry_set_ = entry_set; } - virtual void DoomEntrySet( - scoped_ptr<std::vector<uint64> > entry_hashes, - const base::Callback<void(int)>& reply_callback) OVERRIDE { - last_doom_entry_hashes_ = *entry_hashes.get(); - last_doom_reply_callback_ = reply_callback; - ++doom_entry_set_calls_; - } - void GetAndResetDiskWriteEntrySet(SimpleIndex::EntrySet* entry_set) { entry_set->swap(disk_write_entry_set_); } @@ -86,55 +82,65 @@ class MockSimpleIndexFile : public SimpleIndexFile, SimpleIndexLoadResult* load_result() const { return load_result_; } int load_index_entries_calls() const { return load_index_entries_calls_; } int disk_writes() const { return disk_writes_; } - const std::vector<uint64>& last_doom_entry_hashes() const { - return last_doom_entry_hashes_; - } - int doom_entry_set_calls() const { return doom_entry_set_calls_; } private: base::Closure load_callback_; SimpleIndexLoadResult* load_result_; int load_index_entries_calls_; - std::vector<uint64> last_doom_entry_hashes_; - int doom_entry_set_calls_; - base::Callback<void(int)> last_doom_reply_callback_; int disk_writes_; SimpleIndex::EntrySet disk_write_entry_set_; }; -class SimpleIndexTest : public testing::Test { - public: +class SimpleIndexTest : public testing::Test, public SimpleIndexDelegate { + protected: + SimpleIndexTest() + : hashes_(base::Bind(&HashesInitializer)), + doom_entries_calls_(0) {} + + static uint64 HashesInitializer(size_t hash_index) { + return disk_cache::simple_util::GetEntryHashKey( + base::StringPrintf("key%d", static_cast<int>(hash_index))); + } + virtual void SetUp() OVERRIDE { scoped_ptr<MockSimpleIndexFile> index_file(new MockSimpleIndexFile()); index_file_ = index_file->AsWeakPtr(); - index_.reset(new SimpleIndex(NULL, base::FilePath(), + index_.reset(new SimpleIndex(NULL, this, net::DISK_CACHE, index_file.PassAs<SimpleIndexFile>())); index_->Initialize(base::Time()); } void WaitForTimeChange() { - base::Time now(base::Time::Now()); - + const base::Time initial_time = base::Time::Now(); do { base::PlatformThread::YieldCurrentThread(); - } while (now == base::Time::Now()); + } while (base::Time::Now() - + initial_time < base::TimeDelta::FromSeconds(1)); + } + + // From SimpleIndexDelegate: + virtual void DoomEntries(std::vector<uint64>* entry_hashes, + const net::CompletionCallback& callback) OVERRIDE { + std::for_each(entry_hashes->begin(), entry_hashes->end(), + std::bind1st(std::mem_fun(&SimpleIndex::Remove), + index_.get())); + last_doom_entry_hashes_ = *entry_hashes; + ++doom_entries_calls_; } // Redirect to allow single "friend" declaration in base class. - bool GetEntryForTesting(const std::string& key, EntryMetadata* metadata) { - const uint64 hash_key = simple_util::GetEntryHashKey(key); - SimpleIndex::EntrySet::iterator it = index_->entries_set_.find(hash_key); + bool GetEntryForTesting(uint64 key, EntryMetadata* metadata) { + SimpleIndex::EntrySet::iterator it = index_->entries_set_.find(key); if (index_->entries_set_.end() == it) return false; *metadata = it->second; return true; } - void InsertIntoIndexFileReturn(const std::string& key, + void InsertIntoIndexFileReturn(uint64 hash_key, base::Time last_used_time, - uint64 entry_size) { - uint64 hash_key(simple_util::GetEntryHashKey(key)); + int entry_size) { index_file_->load_result()->entries.insert(std::make_pair( hash_key, EntryMetadata(last_used_time, entry_size))); } @@ -148,22 +154,35 @@ class SimpleIndexTest : public testing::Test { SimpleIndex* index() { return index_.get(); } const MockSimpleIndexFile* index_file() const { return index_file_.get(); } - protected: + const std::vector<uint64>& last_doom_entry_hashes() const { + return last_doom_entry_hashes_; + } + int doom_entries_calls() const { return doom_entries_calls_; } + + + const simple_util::ImmutableArray<uint64, 16> hashes_; scoped_ptr<SimpleIndex> index_; base::WeakPtr<MockSimpleIndexFile> index_file_; + + std::vector<uint64> last_doom_entry_hashes_; + int doom_entries_calls_; }; TEST_F(EntryMetadataTest, Basics) { EntryMetadata entry_metadata; - EXPECT_EQ(base::Time::FromInternalValue(0), entry_metadata.GetLastUsedTime()); - EXPECT_EQ(size_t(0), entry_metadata.GetEntrySize()); + EXPECT_EQ(base::Time(), entry_metadata.GetLastUsedTime()); + EXPECT_EQ(0, entry_metadata.GetEntrySize()); entry_metadata = NewEntryMetadataWithValues(); CheckEntryMetadataValues(entry_metadata); - const base::Time new_time = base::Time::FromInternalValue(5); + const base::Time new_time = base::Time::Now(); entry_metadata.SetLastUsedTime(new_time); - EXPECT_EQ(new_time, entry_metadata.GetLastUsedTime()); + + EXPECT_LT(new_time - base::TimeDelta::FromSeconds(2), + entry_metadata.GetLastUsedTime()); + EXPECT_GT(new_time + base::TimeDelta::FromSeconds(2), + entry_metadata.GetLastUsedTime()); } TEST_F(EntryMetadataTest, Serialize) { @@ -181,31 +200,31 @@ TEST_F(EntryMetadataTest, Serialize) { TEST_F(SimpleIndexTest, IndexSizeCorrectOnMerge) { typedef disk_cache::SimpleIndex::EntrySet EntrySet; index()->SetMaxSize(100); - index()->Insert("two"); - index()->UpdateEntrySize("two", 2); - index()->Insert("five"); - index()->UpdateEntrySize("five", 5); - index()->Insert("seven"); - index()->UpdateEntrySize("seven", 7); - EXPECT_EQ(14U, index()->cache_size_); + index()->Insert(hashes_.at<2>()); + index()->UpdateEntrySize(hashes_.at<2>(), 2); + index()->Insert(hashes_.at<3>()); + index()->UpdateEntrySize(hashes_.at<3>(), 3); + index()->Insert(hashes_.at<4>()); + index()->UpdateEntrySize(hashes_.at<4>(), 4); + EXPECT_EQ(9U, index()->cache_size_); { scoped_ptr<SimpleIndexLoadResult> result(new SimpleIndexLoadResult()); result->did_load = true; index()->MergeInitializingSet(result.Pass()); } - EXPECT_EQ(14U, index()->cache_size_); + EXPECT_EQ(9U, index()->cache_size_); { scoped_ptr<SimpleIndexLoadResult> result(new SimpleIndexLoadResult()); result->did_load = true; - const uint64 new_hash_key = simple_util::GetEntryHashKey("eleven"); + const uint64 new_hash_key = hashes_.at<11>(); result->entries.insert( std::make_pair(new_hash_key, EntryMetadata(base::Time::Now(), 11))); - const uint64 redundant_hash_key = simple_util::GetEntryHashKey("seven"); + const uint64 redundant_hash_key = hashes_.at<4>(); result->entries.insert(std::make_pair(redundant_hash_key, - EntryMetadata(base::Time::Now(), 7))); + EntryMetadata(base::Time::Now(), 4))); index()->MergeInitializingSet(result.Pass()); } - EXPECT_EQ(2U + 5U + 7U + 11U, index()->cache_size_); + EXPECT_EQ(2U + 3U + 4U + 11U, index()->cache_size_); } // State of index changes as expected with an insert and a remove. @@ -213,22 +232,22 @@ TEST_F(SimpleIndexTest, BasicInsertRemove) { // Confirm blank state. EntryMetadata metadata; EXPECT_EQ(base::Time(), metadata.GetLastUsedTime()); - EXPECT_EQ(0ul, metadata.GetEntrySize()); + EXPECT_EQ(0, metadata.GetEntrySize()); // Confirm state after insert. - index()->Insert("key1"); - EXPECT_TRUE(GetEntryForTesting("key1", &metadata)); + index()->Insert(hashes_.at<1>()); + ASSERT_TRUE(GetEntryForTesting(hashes_.at<1>(), &metadata)); base::Time now(base::Time::Now()); EXPECT_LT(now - base::TimeDelta::FromMinutes(1), metadata.GetLastUsedTime()); EXPECT_GT(now + base::TimeDelta::FromMinutes(1), metadata.GetLastUsedTime()); - EXPECT_EQ(0ul, metadata.GetEntrySize()); + EXPECT_EQ(0, metadata.GetEntrySize()); // Confirm state after remove. metadata = EntryMetadata(); - index()->Remove("key1"); - EXPECT_FALSE(GetEntryForTesting("key1", &metadata)); + index()->Remove(hashes_.at<1>()); + EXPECT_FALSE(GetEntryForTesting(hashes_.at<1>(), &metadata)); EXPECT_EQ(base::Time(), metadata.GetLastUsedTime()); - EXPECT_EQ(0ul, metadata.GetEntrySize()); + EXPECT_EQ(0, metadata.GetEntrySize()); } TEST_F(SimpleIndexTest, Has) { @@ -237,20 +256,21 @@ TEST_F(SimpleIndexTest, Has) { EXPECT_EQ(1, index_file_->load_index_entries_calls()); // Confirm "Has()" always returns true before the callback is called. - EXPECT_TRUE(index()->Has(kKey1Hash)); - index()->Insert("key1"); - EXPECT_TRUE(index()->Has(kKey1Hash)); - index()->Remove("key1"); + const uint64 kHash1 = hashes_.at<1>(); + EXPECT_TRUE(index()->Has(kHash1)); + index()->Insert(kHash1); + EXPECT_TRUE(index()->Has(kHash1)); + index()->Remove(kHash1); // TODO(rdsmith): Maybe return false on explicitly removed entries? - EXPECT_TRUE(index()->Has(kKey1Hash)); + EXPECT_TRUE(index()->Has(kHash1)); ReturnIndexFile(); // Confirm "Has() returns conditionally now. - EXPECT_FALSE(index()->Has(kKey1Hash)); - index()->Insert("key1"); - EXPECT_TRUE(index()->Has(kKey1Hash)); - index()->Remove("key1"); + EXPECT_FALSE(index()->Has(kHash1)); + index()->Insert(kHash1); + EXPECT_TRUE(index()->Has(kHash1)); + index()->Remove(kHash1); } TEST_F(SimpleIndexTest, UseIfExists) { @@ -260,37 +280,38 @@ TEST_F(SimpleIndexTest, UseIfExists) { // Confirm "UseIfExists()" always returns true before the callback is called // and updates mod time if the entry was really there. + const uint64 kHash1 = hashes_.at<1>(); EntryMetadata metadata1, metadata2; - EXPECT_TRUE(index()->UseIfExists("key1")); - EXPECT_FALSE(GetEntryForTesting("key1", &metadata1)); - index()->Insert("key1"); - EXPECT_TRUE(index()->UseIfExists("key1")); - EXPECT_TRUE(GetEntryForTesting("key1", &metadata1)); + EXPECT_TRUE(index()->UseIfExists(kHash1)); + EXPECT_FALSE(GetEntryForTesting(kHash1, &metadata1)); + index()->Insert(kHash1); + EXPECT_TRUE(index()->UseIfExists(kHash1)); + EXPECT_TRUE(GetEntryForTesting(kHash1, &metadata1)); WaitForTimeChange(); - EXPECT_TRUE(GetEntryForTesting("key1", &metadata2)); + EXPECT_TRUE(GetEntryForTesting(kHash1, &metadata2)); EXPECT_EQ(metadata1.GetLastUsedTime(), metadata2.GetLastUsedTime()); - EXPECT_TRUE(index()->UseIfExists("key1")); - EXPECT_TRUE(GetEntryForTesting("key1", &metadata2)); + EXPECT_TRUE(index()->UseIfExists(kHash1)); + EXPECT_TRUE(GetEntryForTesting(kHash1, &metadata2)); EXPECT_LT(metadata1.GetLastUsedTime(), metadata2.GetLastUsedTime()); - index()->Remove("key1"); - EXPECT_TRUE(index()->UseIfExists("key1")); + index()->Remove(kHash1); + EXPECT_TRUE(index()->UseIfExists(kHash1)); ReturnIndexFile(); // Confirm "UseIfExists() returns conditionally now - EXPECT_FALSE(index()->UseIfExists("key1")); - EXPECT_FALSE(GetEntryForTesting("key1", &metadata1)); - index()->Insert("key1"); - EXPECT_TRUE(index()->UseIfExists("key1")); - EXPECT_TRUE(GetEntryForTesting("key1", &metadata1)); + EXPECT_FALSE(index()->UseIfExists(kHash1)); + EXPECT_FALSE(GetEntryForTesting(kHash1, &metadata1)); + index()->Insert(kHash1); + EXPECT_TRUE(index()->UseIfExists(kHash1)); + EXPECT_TRUE(GetEntryForTesting(kHash1, &metadata1)); WaitForTimeChange(); - EXPECT_TRUE(GetEntryForTesting("key1", &metadata2)); + EXPECT_TRUE(GetEntryForTesting(kHash1, &metadata2)); EXPECT_EQ(metadata1.GetLastUsedTime(), metadata2.GetLastUsedTime()); - EXPECT_TRUE(index()->UseIfExists("key1")); - EXPECT_TRUE(GetEntryForTesting("key1", &metadata2)); + EXPECT_TRUE(index()->UseIfExists(kHash1)); + EXPECT_TRUE(GetEntryForTesting(kHash1, &metadata2)); EXPECT_LT(metadata1.GetLastUsedTime(), metadata2.GetLastUsedTime()); - index()->Remove("key1"); - EXPECT_FALSE(index()->UseIfExists("key1")); + index()->Remove(kHash1); + EXPECT_FALSE(index()->UseIfExists(kHash1)); } TEST_F(SimpleIndexTest, UpdateEntrySize) { @@ -298,43 +319,47 @@ TEST_F(SimpleIndexTest, UpdateEntrySize) { index()->SetMaxSize(1000); - InsertIntoIndexFileReturn("key1", - now - base::TimeDelta::FromDays(2), - 475u); + const uint64 kHash1 = hashes_.at<1>(); + InsertIntoIndexFileReturn(kHash1, now - base::TimeDelta::FromDays(2), 475); ReturnIndexFile(); EntryMetadata metadata; - EXPECT_TRUE(GetEntryForTesting("key1", &metadata)); - EXPECT_EQ(now - base::TimeDelta::FromDays(2), metadata.GetLastUsedTime()); - EXPECT_EQ(475u, metadata.GetEntrySize()); - - index()->UpdateEntrySize("key1", 600u); - EXPECT_TRUE(GetEntryForTesting("key1", &metadata)); - EXPECT_EQ(600u, metadata.GetEntrySize()); + EXPECT_TRUE(GetEntryForTesting(kHash1, &metadata)); + EXPECT_LT( + now - base::TimeDelta::FromDays(2) - base::TimeDelta::FromSeconds(1), + metadata.GetLastUsedTime()); + EXPECT_GT( + now - base::TimeDelta::FromDays(2) + base::TimeDelta::FromSeconds(1), + metadata.GetLastUsedTime()); + EXPECT_EQ(475, metadata.GetEntrySize()); + + index()->UpdateEntrySize(kHash1, 600u); + EXPECT_TRUE(GetEntryForTesting(kHash1, &metadata)); + EXPECT_EQ(600, metadata.GetEntrySize()); EXPECT_EQ(1, index()->GetEntryCount()); } TEST_F(SimpleIndexTest, GetEntryCount) { EXPECT_EQ(0, index()->GetEntryCount()); - index()->Insert("key1"); + index()->Insert(hashes_.at<1>()); EXPECT_EQ(1, index()->GetEntryCount()); - index()->Insert("key2"); + index()->Insert(hashes_.at<2>()); EXPECT_EQ(2, index()->GetEntryCount()); - index()->Insert("key3"); + index()->Insert(hashes_.at<3>()); EXPECT_EQ(3, index()->GetEntryCount()); - index()->Insert("key3"); + index()->Insert(hashes_.at<3>()); EXPECT_EQ(3, index()->GetEntryCount()); - index()->Remove("key2"); + index()->Remove(hashes_.at<2>()); EXPECT_EQ(2, index()->GetEntryCount()); - index()->Insert("key4"); + index()->Insert(hashes_.at<4>()); EXPECT_EQ(3, index()->GetEntryCount()); - index()->Remove("key3"); + index()->Remove(hashes_.at<3>()); EXPECT_EQ(2, index()->GetEntryCount()); - index()->Remove("key3"); + index()->Remove(hashes_.at<3>()); EXPECT_EQ(2, index()->GetEntryCount()); - index()->Remove("key1"); + index()->Remove(hashes_.at<1>()); EXPECT_EQ(1, index()->GetEntryCount()); - index()->Remove("key4"); + index()->Remove(hashes_.at<4>()); EXPECT_EQ(0, index()->GetEntryCount()); } @@ -342,83 +367,97 @@ TEST_F(SimpleIndexTest, GetEntryCount) { TEST_F(SimpleIndexTest, BasicInit) { base::Time now(base::Time::Now()); - InsertIntoIndexFileReturn("key1", + InsertIntoIndexFileReturn(hashes_.at<1>(), now - base::TimeDelta::FromDays(2), 10u); - InsertIntoIndexFileReturn("key2", + InsertIntoIndexFileReturn(hashes_.at<2>(), now - base::TimeDelta::FromDays(3), 100u); ReturnIndexFile(); EntryMetadata metadata; - EXPECT_TRUE(GetEntryForTesting("key1", &metadata)); - EXPECT_EQ(now - base::TimeDelta::FromDays(2), metadata.GetLastUsedTime()); - EXPECT_EQ(10ul, metadata.GetEntrySize()); - EXPECT_TRUE(GetEntryForTesting("key2", &metadata)); - EXPECT_EQ(now - base::TimeDelta::FromDays(3), metadata.GetLastUsedTime()); - EXPECT_EQ(100ul, metadata.GetEntrySize()); + EXPECT_TRUE(GetEntryForTesting(hashes_.at<1>(), &metadata)); + EXPECT_LT( + now - base::TimeDelta::FromDays(2) - base::TimeDelta::FromSeconds(1), + metadata.GetLastUsedTime()); + EXPECT_GT( + now - base::TimeDelta::FromDays(2) + base::TimeDelta::FromSeconds(1), + metadata.GetLastUsedTime()); + EXPECT_EQ(10, metadata.GetEntrySize()); + EXPECT_TRUE(GetEntryForTesting(hashes_.at<2>(), &metadata)); + EXPECT_LT( + now - base::TimeDelta::FromDays(3) - base::TimeDelta::FromSeconds(1), + metadata.GetLastUsedTime()); + EXPECT_GT( + now - base::TimeDelta::FromDays(3) + base::TimeDelta::FromSeconds(1), + metadata.GetLastUsedTime()); + EXPECT_EQ(100, metadata.GetEntrySize()); } // Remove something that's going to come in from the loaded index. TEST_F(SimpleIndexTest, RemoveBeforeInit) { - index()->Remove("key1"); + const uint64 kHash1 = hashes_.at<1>(); + index()->Remove(kHash1); - InsertIntoIndexFileReturn("key1", + InsertIntoIndexFileReturn(kHash1, base::Time::Now() - base::TimeDelta::FromDays(2), 10u); ReturnIndexFile(); - EXPECT_FALSE(index()->Has(kKey1Hash)); + EXPECT_FALSE(index()->Has(kHash1)); } // Insert something that's going to come in from the loaded index; correct // result? TEST_F(SimpleIndexTest, InsertBeforeInit) { - index()->Insert("key1"); + const uint64 kHash1 = hashes_.at<1>(); + index()->Insert(kHash1); - InsertIntoIndexFileReturn("key1", + InsertIntoIndexFileReturn(kHash1, base::Time::Now() - base::TimeDelta::FromDays(2), 10u); ReturnIndexFile(); EntryMetadata metadata; - EXPECT_TRUE(GetEntryForTesting("key1", &metadata)); + EXPECT_TRUE(GetEntryForTesting(kHash1, &metadata)); base::Time now(base::Time::Now()); EXPECT_LT(now - base::TimeDelta::FromMinutes(1), metadata.GetLastUsedTime()); EXPECT_GT(now + base::TimeDelta::FromMinutes(1), metadata.GetLastUsedTime()); - EXPECT_EQ(0ul, metadata.GetEntrySize()); + EXPECT_EQ(0, metadata.GetEntrySize()); } // Insert and Remove something that's going to come in from the loaded index. TEST_F(SimpleIndexTest, InsertRemoveBeforeInit) { - index()->Insert("key1"); - index()->Remove("key1"); + const uint64 kHash1 = hashes_.at<1>(); + index()->Insert(kHash1); + index()->Remove(kHash1); - InsertIntoIndexFileReturn("key1", + InsertIntoIndexFileReturn(kHash1, base::Time::Now() - base::TimeDelta::FromDays(2), 10u); ReturnIndexFile(); - EXPECT_FALSE(index()->Has(kKey1Hash)); + EXPECT_FALSE(index()->Has(kHash1)); } // Insert and Remove something that's going to come in from the loaded index. TEST_F(SimpleIndexTest, RemoveInsertBeforeInit) { - index()->Remove("key1"); - index()->Insert("key1"); + const uint64 kHash1 = hashes_.at<1>(); + index()->Remove(kHash1); + index()->Insert(kHash1); - InsertIntoIndexFileReturn("key1", + InsertIntoIndexFileReturn(kHash1, base::Time::Now() - base::TimeDelta::FromDays(2), 10u); ReturnIndexFile(); EntryMetadata metadata; - EXPECT_TRUE(GetEntryForTesting("key1", &metadata)); + EXPECT_TRUE(GetEntryForTesting(kHash1, &metadata)); base::Time now(base::Time::Now()); EXPECT_LT(now - base::TimeDelta::FromMinutes(1), metadata.GetLastUsedTime()); EXPECT_GT(now + base::TimeDelta::FromMinutes(1), metadata.GetLastUsedTime()); - EXPECT_EQ(0ul, metadata.GetEntrySize()); + EXPECT_EQ(0, metadata.GetEntrySize()); } // Do all above tests at once + a non-conflict to test for cross-key @@ -426,81 +465,88 @@ TEST_F(SimpleIndexTest, RemoveInsertBeforeInit) { TEST_F(SimpleIndexTest, AllInitConflicts) { base::Time now(base::Time::Now()); - index()->Remove("key1"); - InsertIntoIndexFileReturn("key1", + index()->Remove(hashes_.at<1>()); + InsertIntoIndexFileReturn(hashes_.at<1>(), now - base::TimeDelta::FromDays(2), 10u); - index()->Insert("key2"); - InsertIntoIndexFileReturn("key2", + index()->Insert(hashes_.at<2>()); + InsertIntoIndexFileReturn(hashes_.at<2>(), now - base::TimeDelta::FromDays(3), 100u); - index()->Insert("key3"); - index()->Remove("key3"); - InsertIntoIndexFileReturn("key3", + index()->Insert(hashes_.at<3>()); + index()->Remove(hashes_.at<3>()); + InsertIntoIndexFileReturn(hashes_.at<3>(), now - base::TimeDelta::FromDays(4), 1000u); - index()->Remove("key4"); - index()->Insert("key4"); - InsertIntoIndexFileReturn("key4", + index()->Remove(hashes_.at<4>()); + index()->Insert(hashes_.at<4>()); + InsertIntoIndexFileReturn(hashes_.at<4>(), now - base::TimeDelta::FromDays(5), 10000u); - InsertIntoIndexFileReturn("key5", + InsertIntoIndexFileReturn(hashes_.at<5>(), now - base::TimeDelta::FromDays(6), 100000u); ReturnIndexFile(); - EXPECT_FALSE(index()->Has(kKey1Hash)); + EXPECT_FALSE(index()->Has(hashes_.at<1>())); EntryMetadata metadata; - EXPECT_TRUE(GetEntryForTesting("key2", &metadata)); + EXPECT_TRUE(GetEntryForTesting(hashes_.at<2>(), &metadata)); EXPECT_LT(now - base::TimeDelta::FromMinutes(1), metadata.GetLastUsedTime()); EXPECT_GT(now + base::TimeDelta::FromMinutes(1), metadata.GetLastUsedTime()); - EXPECT_EQ(0ul, metadata.GetEntrySize()); + EXPECT_EQ(0, metadata.GetEntrySize()); - EXPECT_FALSE(index()->Has(kKey3Hash)); + EXPECT_FALSE(index()->Has(hashes_.at<3>())); - EXPECT_TRUE(GetEntryForTesting("key4", &metadata)); + EXPECT_TRUE(GetEntryForTesting(hashes_.at<4>(), &metadata)); EXPECT_LT(now - base::TimeDelta::FromMinutes(1), metadata.GetLastUsedTime()); EXPECT_GT(now + base::TimeDelta::FromMinutes(1), metadata.GetLastUsedTime()); - EXPECT_EQ(0ul, metadata.GetEntrySize()); + EXPECT_EQ(0, metadata.GetEntrySize()); + + EXPECT_TRUE(GetEntryForTesting(hashes_.at<5>(), &metadata)); + + EXPECT_GT( + now - base::TimeDelta::FromDays(6) + base::TimeDelta::FromSeconds(1), + metadata.GetLastUsedTime()); + EXPECT_LT( + now - base::TimeDelta::FromDays(6) - base::TimeDelta::FromSeconds(1), + metadata.GetLastUsedTime()); - EXPECT_TRUE(GetEntryForTesting("key5", &metadata)); - EXPECT_EQ(now - base::TimeDelta::FromDays(6), metadata.GetLastUsedTime()); - EXPECT_EQ(100000u, metadata.GetEntrySize()); + EXPECT_EQ(100000, metadata.GetEntrySize()); } TEST_F(SimpleIndexTest, BasicEviction) { base::Time now(base::Time::Now()); index()->SetMaxSize(1000); - InsertIntoIndexFileReturn("key1", + InsertIntoIndexFileReturn(hashes_.at<1>(), now - base::TimeDelta::FromDays(2), 475u); - index()->Insert("key2"); - index()->UpdateEntrySize("key2", 475); + index()->Insert(hashes_.at<2>()); + index()->UpdateEntrySize(hashes_.at<2>(), 475); ReturnIndexFile(); WaitForTimeChange(); - index()->Insert("key3"); + index()->Insert(hashes_.at<3>()); // Confirm index is as expected: No eviction, everything there. EXPECT_EQ(3, index()->GetEntryCount()); - EXPECT_EQ(0, index_file()->doom_entry_set_calls()); - EXPECT_TRUE(index()->Has(kKey1Hash)); - EXPECT_TRUE(index()->Has(kKey2Hash)); - EXPECT_TRUE(index()->Has(kKey3Hash)); + EXPECT_EQ(0, doom_entries_calls()); + EXPECT_TRUE(index()->Has(hashes_.at<1>())); + EXPECT_TRUE(index()->Has(hashes_.at<2>())); + EXPECT_TRUE(index()->Has(hashes_.at<3>())); // Trigger an eviction, and make sure the right things are tossed. // TODO(rdsmith): This is dependent on the innards of the implementation // as to at exactly what point we trigger eviction. Not sure how to fix // that. - index()->UpdateEntrySize("key3", 475); - EXPECT_EQ(1, index_file()->doom_entry_set_calls()); + index()->UpdateEntrySize(hashes_.at<3>(), 475); + EXPECT_EQ(1, doom_entries_calls()); EXPECT_EQ(1, index()->GetEntryCount()); - EXPECT_FALSE(index()->Has(kKey1Hash)); - EXPECT_FALSE(index()->Has(kKey2Hash)); - EXPECT_TRUE(index()->Has(kKey3Hash)); - ASSERT_EQ(2u, index_file_->last_doom_entry_hashes().size()); + EXPECT_FALSE(index()->Has(hashes_.at<1>())); + EXPECT_FALSE(index()->Has(hashes_.at<2>())); + EXPECT_TRUE(index()->Has(hashes_.at<3>())); + ASSERT_EQ(2u, last_doom_entry_hashes().size()); } // Confirm all the operations queue a disk write at some point in the @@ -511,20 +557,21 @@ TEST_F(SimpleIndexTest, DiskWriteQueued) { EXPECT_FALSE(index()->write_to_disk_timer_.IsRunning()); - index()->Insert("key1"); + const uint64 kHash1 = hashes_.at<1>(); + index()->Insert(kHash1); EXPECT_TRUE(index()->write_to_disk_timer_.IsRunning()); index()->write_to_disk_timer_.Stop(); EXPECT_FALSE(index()->write_to_disk_timer_.IsRunning()); - index()->UseIfExists("key1"); + index()->UseIfExists(kHash1); EXPECT_TRUE(index()->write_to_disk_timer_.IsRunning()); index()->write_to_disk_timer_.Stop(); - index()->UpdateEntrySize("key1", 20); + index()->UpdateEntrySize(kHash1, 20); EXPECT_TRUE(index()->write_to_disk_timer_.IsRunning()); index()->write_to_disk_timer_.Stop(); - index()->Remove("key1"); + index()->Remove(kHash1); EXPECT_TRUE(index()->write_to_disk_timer_.IsRunning()); index()->write_to_disk_timer_.Stop(); } @@ -535,8 +582,9 @@ TEST_F(SimpleIndexTest, DiskWriteExecuted) { EXPECT_FALSE(index()->write_to_disk_timer_.IsRunning()); - index()->Insert("key1"); - index()->UpdateEntrySize("key1", 20); + const uint64 kHash1 = hashes_.at<1>(); + index()->Insert(kHash1); + index()->UpdateEntrySize(kHash1, 20); EXPECT_TRUE(index()->write_to_disk_timer_.IsRunning()); base::Closure user_task(index()->write_to_disk_timer_.user_task()); index()->write_to_disk_timer_.Stop(); @@ -547,14 +595,14 @@ TEST_F(SimpleIndexTest, DiskWriteExecuted) { SimpleIndex::EntrySet entry_set; index_file_->GetAndResetDiskWriteEntrySet(&entry_set); - uint64 hash_key(simple_util::GetEntryHashKey("key1")); + uint64 hash_key = kHash1; base::Time now(base::Time::Now()); ASSERT_EQ(1u, entry_set.size()); EXPECT_EQ(hash_key, entry_set.begin()->first); const EntryMetadata& entry1(entry_set.begin()->second); EXPECT_LT(now - base::TimeDelta::FromMinutes(1), entry1.GetLastUsedTime()); EXPECT_GT(now + base::TimeDelta::FromMinutes(1), entry1.GetLastUsedTime()); - EXPECT_EQ(20u, entry1.GetEntrySize()); + EXPECT_EQ(20, entry1.GetEntrySize()); } TEST_F(SimpleIndexTest, DiskWritePostponed) { @@ -563,16 +611,16 @@ TEST_F(SimpleIndexTest, DiskWritePostponed) { EXPECT_FALSE(index()->write_to_disk_timer_.IsRunning()); - index()->Insert("key1"); - index()->UpdateEntrySize("key1", 20); + index()->Insert(hashes_.at<1>()); + index()->UpdateEntrySize(hashes_.at<1>(), 20); EXPECT_TRUE(index()->write_to_disk_timer_.IsRunning()); base::TimeTicks expected_trigger( index()->write_to_disk_timer_.desired_run_time()); WaitForTimeChange(); EXPECT_EQ(expected_trigger, index()->write_to_disk_timer_.desired_run_time()); - index()->Insert("key2"); - index()->UpdateEntrySize("key2", 40); + index()->Insert(hashes_.at<2>()); + index()->UpdateEntrySize(hashes_.at<2>(), 40); EXPECT_TRUE(index()->write_to_disk_timer_.IsRunning()); EXPECT_LT(expected_trigger, index()->write_to_disk_timer_.desired_run_time()); index()->write_to_disk_timer_.Stop(); diff --git a/chromium/net/disk_cache/simple/simple_synchronous_entry.cc b/chromium/net/disk_cache/simple/simple_synchronous_entry.cc index e6f1eaa4f21..38e8a3cae99 100644 --- a/chromium/net/disk_cache/simple/simple_synchronous_entry.cc +++ b/chromium/net/disk_cache/simple/simple_synchronous_entry.cc @@ -14,11 +14,12 @@ #include "base/file_util.h" #include "base/hash.h" #include "base/location.h" -#include "base/metrics/histogram.h" #include "base/sha1.h" #include "base/strings/stringprintf.h" #include "net/base/io_buffer.h" #include "net/base/net_errors.h" +#include "net/disk_cache/simple/simple_backend_version.h" +#include "net/disk_cache/simple/simple_histogram_macros.h" #include "net/disk_cache/simple/simple_util.h" #include "third_party/zlib/zlib.h" @@ -30,6 +31,7 @@ using base::PlatformFileError; using base::PlatformFileInfo; using base::PLATFORM_FILE_CREATE; using base::PLATFORM_FILE_ERROR_EXISTS; +using base::PLATFORM_FILE_ERROR_NOT_FOUND; using base::PLATFORM_FILE_OK; using base::PLATFORM_FILE_OPEN; using base::PLATFORM_FILE_READ; @@ -55,20 +57,14 @@ enum OpenEntryResult { }; // Used in histograms, please only add entries at the end. -enum CreateEntryResult { - CREATE_ENTRY_SUCCESS = 0, - CREATE_ENTRY_PLATFORM_FILE_ERROR = 1, - CREATE_ENTRY_CANT_WRITE_HEADER = 2, - CREATE_ENTRY_CANT_WRITE_KEY = 3, - CREATE_ENTRY_MAX = 4, -}; - -// Used in histograms, please only add entries at the end. enum WriteResult { WRITE_RESULT_SUCCESS = 0, WRITE_RESULT_PRETRUNCATE_FAILURE, WRITE_RESULT_WRITE_FAILURE, WRITE_RESULT_TRUNCATE_FAILURE, + WRITE_RESULT_LAZY_STREAM_ENTRY_DOOMED, + WRITE_RESULT_LAZY_CREATE_FAILURE, + WRITE_RESULT_LAZY_INITIALIZE_FAILURE, WRITE_RESULT_MAX, }; @@ -87,72 +83,98 @@ enum CloseResult { CLOSE_RESULT_WRITE_FAILURE, }; -void RecordSyncOpenResult(OpenEntryResult result, bool had_index) { +void RecordSyncOpenResult(net::CacheType cache_type, + OpenEntryResult result, + bool had_index) { DCHECK_GT(OPEN_ENTRY_MAX, result); - UMA_HISTOGRAM_ENUMERATION( - "SimpleCache.SyncOpenResult", result, OPEN_ENTRY_MAX); + SIMPLE_CACHE_UMA(ENUMERATION, + "SyncOpenResult", cache_type, result, OPEN_ENTRY_MAX); if (had_index) { - UMA_HISTOGRAM_ENUMERATION( - "SimpleCache.SyncOpenResult_WithIndex", result, OPEN_ENTRY_MAX); + SIMPLE_CACHE_UMA(ENUMERATION, + "SyncOpenResult_WithIndex", cache_type, + result, OPEN_ENTRY_MAX); } else { - UMA_HISTOGRAM_ENUMERATION( - "SimpleCache.SyncOpenResult_WithoutIndex", result, OPEN_ENTRY_MAX); + SIMPLE_CACHE_UMA(ENUMERATION, + "SyncOpenResult_WithoutIndex", cache_type, + result, OPEN_ENTRY_MAX); } } -void RecordSyncCreateResult(CreateEntryResult result, bool had_index) { - DCHECK_GT(CREATE_ENTRY_MAX, result); - UMA_HISTOGRAM_ENUMERATION( - "SimpleCache.SyncCreateResult", result, CREATE_ENTRY_MAX); - if (had_index) { - UMA_HISTOGRAM_ENUMERATION( - "SimpleCache.SyncCreateResult_WithIndex", result, CREATE_ENTRY_MAX); - } else { - UMA_HISTOGRAM_ENUMERATION( - "SimpleCache.SyncCreateResult_WithoutIndex", result, CREATE_ENTRY_MAX); - } +void RecordWriteResult(net::CacheType cache_type, WriteResult result) { + SIMPLE_CACHE_UMA(ENUMERATION, + "SyncWriteResult", cache_type, result, WRITE_RESULT_MAX); } -void RecordWriteResult(WriteResult result) { - UMA_HISTOGRAM_ENUMERATION( - "SimpleCache.SyncWriteResult", result, WRITE_RESULT_MAX); +void RecordCheckEOFResult(net::CacheType cache_type, CheckEOFResult result) { + SIMPLE_CACHE_UMA(ENUMERATION, + "SyncCheckEOFResult", cache_type, + result, CHECK_EOF_RESULT_MAX); } -void RecordCheckEOFResult(CheckEOFResult result) { - UMA_HISTOGRAM_ENUMERATION( - "SimpleCache.SyncCheckEOFResult", result, CHECK_EOF_RESULT_MAX); +void RecordCloseResult(net::CacheType cache_type, CloseResult result) { + SIMPLE_CACHE_UMA(ENUMERATION, + "SyncCloseResult", cache_type, result, WRITE_RESULT_MAX); } -void RecordCloseResult(CloseResult result) { - UMA_HISTOGRAM_ENUMERATION( - "SimpleCache.SyncCloseResult", result, WRITE_RESULT_MAX); +bool CanOmitEmptyFile(int file_index) { + DCHECK_LE(0, file_index); + DCHECK_GT(disk_cache::kSimpleEntryFileCount, file_index); + return file_index == disk_cache::simple_util::GetFileIndexFromStreamIndex(2); } } // namespace namespace disk_cache { -using simple_util::ConvertEntryHashKeyToHexString; using simple_util::GetEntryHashKey; -using simple_util::GetFilenameFromEntryHashAndIndex; +using simple_util::GetFilenameFromEntryHashAndFileIndex; using simple_util::GetDataSizeFromKeyAndFileSize; using simple_util::GetFileSizeFromKeyAndDataSize; -using simple_util::GetFileOffsetFromKeyAndDataOffset; +using simple_util::GetFileIndexFromStreamIndex; + +SimpleEntryStat::SimpleEntryStat(base::Time last_used, + base::Time last_modified, + const int32 data_size[]) + : last_used_(last_used), + last_modified_(last_modified) { + memcpy(data_size_, data_size, sizeof(data_size_)); +} + +int SimpleEntryStat::GetOffsetInFile(const std::string& key, + int offset, + int stream_index) const { + const int64 headers_size = sizeof(SimpleFileHeader) + key.size(); + const int64 additional_offset = + stream_index == 0 ? data_size_[1] + sizeof(SimpleFileEOF) : 0; + return headers_size + offset + additional_offset; +} + +int SimpleEntryStat::GetEOFOffsetInFile(const std::string& key, + int stream_index) const { + return GetOffsetInFile(key, data_size_[stream_index], stream_index); +} -SimpleEntryStat::SimpleEntryStat() {} +int SimpleEntryStat::GetLastEOFOffsetInFile(const std::string& key, + int stream_index) const { + const int file_index = GetFileIndexFromStreamIndex(stream_index); + const int eof_data_offset = + file_index == 0 ? data_size_[0] + data_size_[1] + sizeof(SimpleFileEOF) + : data_size_[2]; + return GetOffsetInFile(key, eof_data_offset, stream_index); +} -SimpleEntryStat::SimpleEntryStat(base::Time last_used_p, - base::Time last_modified_p, - const int32 data_size_p[]) - : last_used(last_used_p), - last_modified(last_modified_p) { - memcpy(data_size, data_size_p, sizeof(data_size)); +int SimpleEntryStat::GetFileSize(const std::string& key, int file_index) const { + const int total_data_size = + file_index == 0 ? data_size_[0] + data_size_[1] + sizeof(SimpleFileEOF) + : data_size_[2]; + return GetFileSizeFromKeyAndDataSize(key, total_data_size); } SimpleEntryCreationResults::SimpleEntryCreationResults( SimpleEntryStat entry_stat) : sync_entry(NULL), entry_stat(entry_stat), + stream_0_crc32(crc32(0, Z_NULL, 0)), result(net::OK) { } @@ -181,26 +203,33 @@ SimpleSynchronousEntry::EntryOperationData::EntryOperationData(int index_p, SimpleSynchronousEntry::EntryOperationData::EntryOperationData(int index_p, int offset_p, int buf_len_p, - bool truncate_p) + bool truncate_p, + bool doomed_p) : index(index_p), offset(offset_p), buf_len(buf_len_p), - truncate(truncate_p) {} + truncate(truncate_p), + doomed(doomed_p) {} // static void SimpleSynchronousEntry::OpenEntry( + net::CacheType cache_type, const FilePath& path, const uint64 entry_hash, bool had_index, SimpleEntryCreationResults *out_results) { - SimpleSynchronousEntry* sync_entry = new SimpleSynchronousEntry(path, "", - entry_hash); - out_results->result = sync_entry->InitializeForOpen( - had_index, &out_results->entry_stat); + SimpleSynchronousEntry* sync_entry = + new SimpleSynchronousEntry(cache_type, path, "", entry_hash); + out_results->result = + sync_entry->InitializeForOpen(had_index, + &out_results->entry_stat, + &out_results->stream_0_data, + &out_results->stream_0_crc32); if (out_results->result != net::OK) { sync_entry->Doom(); delete sync_entry; out_results->sync_entry = NULL; + out_results->stream_0_data = NULL; return; } out_results->sync_entry = sync_entry; @@ -208,14 +237,15 @@ void SimpleSynchronousEntry::OpenEntry( // static void SimpleSynchronousEntry::CreateEntry( + net::CacheType cache_type, const FilePath& path, const std::string& key, const uint64 entry_hash, bool had_index, SimpleEntryCreationResults *out_results) { DCHECK_EQ(entry_hash, GetEntryHashKey(key)); - SimpleSynchronousEntry* sync_entry = new SimpleSynchronousEntry(path, key, - entry_hash); + SimpleSynchronousEntry* sync_entry = + new SimpleSynchronousEntry(cache_type, path, key, entry_hash); out_results->result = sync_entry->InitializeForCreate( had_index, &out_results->entry_stat); if (out_results->result != net::OK) { @@ -228,37 +258,17 @@ void SimpleSynchronousEntry::CreateEntry( out_results->sync_entry = sync_entry; } -// TODO(gavinp): Move this function to its correct location in this .cc file. -// static -bool SimpleSynchronousEntry::DeleteFilesForEntryHash( - const FilePath& path, - const uint64 entry_hash) { - bool result = true; - for (int i = 0; i < kSimpleEntryFileCount; ++i) { - FilePath to_delete = path.AppendASCII( - GetFilenameFromEntryHashAndIndex(entry_hash, i)); - if (!base::DeleteFile(to_delete, false)) { - result = false; - DLOG(ERROR) << "Could not delete " << to_delete.MaybeAsASCII(); - } - } - return result; -} - // static -void SimpleSynchronousEntry::DoomEntry( +int SimpleSynchronousEntry::DoomEntry( const FilePath& path, - const std::string& key, - uint64 entry_hash, - int* out_result) { - DCHECK_EQ(entry_hash, GetEntryHashKey(key)); - bool deleted_well = DeleteFilesForEntryHash(path, entry_hash); - *out_result = deleted_well ? net::OK : net::ERR_FAILED; + uint64 entry_hash) { + const bool deleted_well = DeleteFilesForEntryHash(path, entry_hash); + return deleted_well ? net::OK : net::ERR_FAILED; } // static int SimpleSynchronousEntry::DoomEntrySet( - scoped_ptr<std::vector<uint64> > key_hashes, + const std::vector<uint64>* key_hashes, const FilePath& path) { const size_t did_delete_count = std::count_if( key_hashes->begin(), key_hashes->end(), std::bind1st( @@ -269,17 +279,21 @@ int SimpleSynchronousEntry::DoomEntrySet( void SimpleSynchronousEntry::ReadData(const EntryOperationData& in_entry_op, net::IOBuffer* out_buf, uint32* out_crc32, - base::Time* out_last_used, + SimpleEntryStat* entry_stat, int* out_result) const { DCHECK(initialized_); - int64 file_offset = - GetFileOffsetFromKeyAndDataOffset(key_, in_entry_op.offset); - int bytes_read = ReadPlatformFile(files_[in_entry_op.index], - file_offset, - out_buf->data(), - in_entry_op.buf_len); + DCHECK_NE(0, in_entry_op.index); + const int64 file_offset = + entry_stat->GetOffsetInFile(key_, in_entry_op.offset, in_entry_op.index); + int file_index = GetFileIndexFromStreamIndex(in_entry_op.index); + // Zero-length reads and reads to the empty streams of omitted files should + // be handled in the SimpleEntryImpl. + DCHECK_LT(0, in_entry_op.buf_len); + DCHECK(!empty_file_omitted_[file_index]); + int bytes_read = ReadPlatformFile( + files_[file_index], file_offset, out_buf->data(), in_entry_op.buf_len); if (bytes_read > 0) { - *out_last_used = Time::Now(); + entry_stat->set_last_used(Time::Now()); *out_crc32 = crc32(crc32(0L, Z_NULL, 0), reinterpret_cast<const Bytef*>(out_buf->data()), bytes_read); @@ -295,143 +309,197 @@ void SimpleSynchronousEntry::ReadData(const EntryOperationData& in_entry_op, void SimpleSynchronousEntry::WriteData(const EntryOperationData& in_entry_op, net::IOBuffer* in_buf, SimpleEntryStat* out_entry_stat, - int* out_result) const { + int* out_result) { DCHECK(initialized_); + DCHECK_NE(0, in_entry_op.index); int index = in_entry_op.index; + int file_index = GetFileIndexFromStreamIndex(index); int offset = in_entry_op.offset; int buf_len = in_entry_op.buf_len; - int truncate = in_entry_op.truncate; + bool truncate = in_entry_op.truncate; + bool doomed = in_entry_op.doomed; + const int64 file_offset = out_entry_stat->GetOffsetInFile( + key_, in_entry_op.offset, in_entry_op.index); + bool extending_by_write = offset + buf_len > out_entry_stat->data_size(index); + + if (empty_file_omitted_[file_index]) { + // Don't create a new file if the entry has been doomed, to avoid it being + // mixed up with a newly-created entry with the same key. + if (doomed) { + DLOG(WARNING) << "Rejecting write to lazily omitted stream " + << in_entry_op.index << " of doomed cache entry."; + RecordWriteResult(cache_type_, WRITE_RESULT_LAZY_STREAM_ENTRY_DOOMED); + *out_result = net::ERR_CACHE_WRITE_FAILURE; + return; + } + PlatformFileError error; + if (!MaybeCreateFile(file_index, FILE_REQUIRED, &error)) { + RecordWriteResult(cache_type_, WRITE_RESULT_LAZY_CREATE_FAILURE); + Doom(); + *out_result = net::ERR_CACHE_WRITE_FAILURE; + return; + } + CreateEntryResult result; + if (!InitializeCreatedFile(file_index, &result)) { + RecordWriteResult(cache_type_, WRITE_RESULT_LAZY_INITIALIZE_FAILURE); + Doom(); + *out_result = net::ERR_CACHE_WRITE_FAILURE; + return; + } + } + DCHECK(!empty_file_omitted_[file_index]); - bool extending_by_write = offset + buf_len > out_entry_stat->data_size[index]; if (extending_by_write) { - // We are extending the file, and need to insure the EOF record is zeroed. - const int64 file_eof_offset = GetFileOffsetFromKeyAndDataOffset( - key_, out_entry_stat->data_size[index]); - if (!TruncatePlatformFile(files_[index], file_eof_offset)) { - RecordWriteResult(WRITE_RESULT_PRETRUNCATE_FAILURE); + // The EOF record and the eventual stream afterward need to be zeroed out. + const int64 file_eof_offset = + out_entry_stat->GetEOFOffsetInFile(key_, index); + if (!TruncatePlatformFile(files_[file_index], file_eof_offset)) { + RecordWriteResult(cache_type_, WRITE_RESULT_PRETRUNCATE_FAILURE); Doom(); *out_result = net::ERR_CACHE_WRITE_FAILURE; return; } } - const int64 file_offset = GetFileOffsetFromKeyAndDataOffset(key_, offset); if (buf_len > 0) { if (WritePlatformFile( - files_[index], file_offset, in_buf->data(), buf_len) != buf_len) { - RecordWriteResult(WRITE_RESULT_WRITE_FAILURE); + files_[file_index], file_offset, in_buf->data(), buf_len) != + buf_len) { + RecordWriteResult(cache_type_, WRITE_RESULT_WRITE_FAILURE); Doom(); *out_result = net::ERR_CACHE_WRITE_FAILURE; return; } } if (!truncate && (buf_len > 0 || !extending_by_write)) { - out_entry_stat->data_size[index] = - std::max(out_entry_stat->data_size[index], offset + buf_len); + out_entry_stat->set_data_size( + index, std::max(out_entry_stat->data_size(index), offset + buf_len)); } else { - if (!TruncatePlatformFile(files_[index], file_offset + buf_len)) { - RecordWriteResult(WRITE_RESULT_TRUNCATE_FAILURE); + out_entry_stat->set_data_size(index, offset + buf_len); + int file_eof_offset = out_entry_stat->GetLastEOFOffsetInFile(key_, index); + if (!TruncatePlatformFile(files_[file_index], file_eof_offset)) { + RecordWriteResult(cache_type_, WRITE_RESULT_TRUNCATE_FAILURE); Doom(); *out_result = net::ERR_CACHE_WRITE_FAILURE; return; } - out_entry_stat->data_size[index] = offset + buf_len; } - RecordWriteResult(WRITE_RESULT_SUCCESS); - out_entry_stat->last_used = out_entry_stat->last_modified = Time::Now(); + RecordWriteResult(cache_type_, WRITE_RESULT_SUCCESS); + base::Time modification_time = Time::Now(); + out_entry_stat->set_last_used(modification_time); + out_entry_stat->set_last_modified(modification_time); *out_result = buf_len; } void SimpleSynchronousEntry::CheckEOFRecord(int index, - int32 data_size, + const SimpleEntryStat& entry_stat, uint32 expected_crc32, int* out_result) const { DCHECK(initialized_); - - SimpleFileEOF eof_record; - int64 file_offset = GetFileOffsetFromKeyAndDataOffset(key_, data_size); - if (ReadPlatformFile(files_[index], - file_offset, - reinterpret_cast<char*>(&eof_record), - sizeof(eof_record)) != sizeof(eof_record)) { - RecordCheckEOFResult(CHECK_EOF_RESULT_READ_FAILURE); + uint32 crc32; + bool has_crc32; + int stream_size; + *out_result = + GetEOFRecordData(index, entry_stat, &has_crc32, &crc32, &stream_size); + if (*out_result != net::OK) { Doom(); - *out_result = net::ERR_CACHE_CHECKSUM_READ_FAILURE; return; } - - if (eof_record.final_magic_number != kSimpleFinalMagicNumber) { - RecordCheckEOFResult(CHECK_EOF_RESULT_MAGIC_NUMBER_MISMATCH); - DLOG(INFO) << "eof record had bad magic number."; - Doom(); - *out_result = net::ERR_CACHE_CHECKSUM_READ_FAILURE; - return; - } - - const bool has_crc = (eof_record.flags & SimpleFileEOF::FLAG_HAS_CRC32) == - SimpleFileEOF::FLAG_HAS_CRC32; - UMA_HISTOGRAM_BOOLEAN("SimpleCache.SyncCheckEOFHasCrc", has_crc); - if (has_crc && eof_record.data_crc32 != expected_crc32) { - RecordCheckEOFResult(CHECK_EOF_RESULT_CRC_MISMATCH); - DLOG(INFO) << "eof record had bad crc."; - Doom(); + if (has_crc32 && crc32 != expected_crc32) { + DLOG(INFO) << "EOF record had bad crc."; *out_result = net::ERR_CACHE_CHECKSUM_MISMATCH; + RecordCheckEOFResult(cache_type_, CHECK_EOF_RESULT_CRC_MISMATCH); + Doom(); return; } - - RecordCheckEOFResult(CHECK_EOF_RESULT_SUCCESS); - *out_result = net::OK; + RecordCheckEOFResult(cache_type_, CHECK_EOF_RESULT_SUCCESS); } void SimpleSynchronousEntry::Close( const SimpleEntryStat& entry_stat, - scoped_ptr<std::vector<CRCRecord> > crc32s_to_write) { + scoped_ptr<std::vector<CRCRecord> > crc32s_to_write, + net::GrowableIOBuffer* stream_0_data) { + DCHECK(stream_0_data); + // Write stream 0 data. + int stream_0_offset = entry_stat.GetOffsetInFile(key_, 0, 0); + if (WritePlatformFile(files_[0], + stream_0_offset, + stream_0_data->data(), + entry_stat.data_size(0)) != entry_stat.data_size(0)) { + RecordCloseResult(cache_type_, CLOSE_RESULT_WRITE_FAILURE); + DLOG(INFO) << "Could not write stream 0 data."; + Doom(); + } + for (std::vector<CRCRecord>::const_iterator it = crc32s_to_write->begin(); it != crc32s_to_write->end(); ++it) { + const int stream_index = it->index; + const int file_index = GetFileIndexFromStreamIndex(stream_index); + if (empty_file_omitted_[file_index]) + continue; + SimpleFileEOF eof_record; + eof_record.stream_size = entry_stat.data_size(stream_index); eof_record.final_magic_number = kSimpleFinalMagicNumber; eof_record.flags = 0; if (it->has_crc32) eof_record.flags |= SimpleFileEOF::FLAG_HAS_CRC32; eof_record.data_crc32 = it->data_crc32; - int64 file_offset = GetFileOffsetFromKeyAndDataOffset( - key_, entry_stat.data_size[it->index]); - if (WritePlatformFile(files_[it->index], - file_offset, + int eof_offset = entry_stat.GetEOFOffsetInFile(key_, stream_index); + // If stream 0 changed size, the file needs to be resized, otherwise the + // next open will yield wrong stream sizes. On stream 1 and stream 2 proper + // resizing of the file is handled in SimpleSynchronousEntry::WriteData(). + if (stream_index == 0 && + !TruncatePlatformFile(files_[file_index], eof_offset)) { + RecordCloseResult(cache_type_, CLOSE_RESULT_WRITE_FAILURE); + DLOG(INFO) << "Could not truncate stream 0 file."; + Doom(); + break; + } + if (WritePlatformFile(files_[file_index], + eof_offset, reinterpret_cast<const char*>(&eof_record), sizeof(eof_record)) != sizeof(eof_record)) { - RecordCloseResult(CLOSE_RESULT_WRITE_FAILURE); + RecordCloseResult(cache_type_, CLOSE_RESULT_WRITE_FAILURE); DLOG(INFO) << "Could not write eof record."; Doom(); break; } - const int64 file_size = file_offset + sizeof(eof_record); - UMA_HISTOGRAM_CUSTOM_COUNTS("SimpleCache.LastClusterSize", - file_size % 4096, 0, 4097, 50); - const int64 cluster_loss = file_size % 4096 ? 4096 - file_size % 4096 : 0; - UMA_HISTOGRAM_PERCENTAGE("SimpleCache.LastClusterLossPercent", - cluster_loss * 100 / (cluster_loss + file_size)); } - for (int i = 0; i < kSimpleEntryFileCount; ++i) { + if (empty_file_omitted_[i]) + continue; + bool did_close_file = ClosePlatformFile(files_[i]); - CHECK(did_close_file); + DCHECK(did_close_file); + const int64 file_size = entry_stat.GetFileSize(key_, i); + SIMPLE_CACHE_UMA(CUSTOM_COUNTS, + "LastClusterSize", cache_type_, + file_size % 4096, 0, 4097, 50); + const int64 cluster_loss = file_size % 4096 ? 4096 - file_size % 4096 : 0; + SIMPLE_CACHE_UMA(PERCENTAGE, + "LastClusterLossPercent", cache_type_, + cluster_loss * 100 / (cluster_loss + file_size)); } - RecordCloseResult(CLOSE_RESULT_SUCCESS); + + RecordCloseResult(cache_type_, CLOSE_RESULT_SUCCESS); have_open_files_ = false; delete this; } -SimpleSynchronousEntry::SimpleSynchronousEntry(const FilePath& path, +SimpleSynchronousEntry::SimpleSynchronousEntry(net::CacheType cache_type, + const FilePath& path, const std::string& key, const uint64 entry_hash) - : path_(path), + : cache_type_(cache_type), + path_(path), entry_hash_(entry_hash), key_(key), have_open_files_(false), initialized_(false) { for (int i = 0; i < kSimpleEntryFileCount; ++i) { files_[i] = kInvalidPlatformFileValue; + empty_file_omitted_[i] = false; } } @@ -441,110 +509,198 @@ SimpleSynchronousEntry::~SimpleSynchronousEntry() { CloseFiles(); } -bool SimpleSynchronousEntry::OpenOrCreateFiles( - bool create, +bool SimpleSynchronousEntry::MaybeOpenFile( + int file_index, + PlatformFileError* out_error) { + DCHECK(out_error); + + FilePath filename = GetFilenameFromFileIndex(file_index); + int flags = PLATFORM_FILE_OPEN | PLATFORM_FILE_READ | PLATFORM_FILE_WRITE; + files_[file_index] = CreatePlatformFile(filename, flags, NULL, out_error); + + if (CanOmitEmptyFile(file_index) && + *out_error == PLATFORM_FILE_ERROR_NOT_FOUND) { + empty_file_omitted_[file_index] = true; + return true; + } + + return *out_error == PLATFORM_FILE_OK; +} + +bool SimpleSynchronousEntry::MaybeCreateFile( + int file_index, + FileRequired file_required, + PlatformFileError* out_error) { + DCHECK(out_error); + + if (CanOmitEmptyFile(file_index) && file_required == FILE_NOT_REQUIRED) { + empty_file_omitted_[file_index] = true; + return true; + } + + FilePath filename = GetFilenameFromFileIndex(file_index); + int flags = PLATFORM_FILE_CREATE | PLATFORM_FILE_READ | PLATFORM_FILE_WRITE; + files_[file_index] = CreatePlatformFile(filename, flags, NULL, out_error); + + empty_file_omitted_[file_index] = false; + + return *out_error == PLATFORM_FILE_OK; +} + +bool SimpleSynchronousEntry::OpenFiles( bool had_index, SimpleEntryStat* out_entry_stat) { for (int i = 0; i < kSimpleEntryFileCount; ++i) { - FilePath filename = path_.AppendASCII( - GetFilenameFromEntryHashAndIndex(entry_hash_, i)); - int flags = PLATFORM_FILE_READ | PLATFORM_FILE_WRITE; - if (create) - flags |= PLATFORM_FILE_CREATE; - else - flags |= PLATFORM_FILE_OPEN; PlatformFileError error; - files_[i] = CreatePlatformFile(filename, flags, NULL, &error); - if (error != PLATFORM_FILE_OK) { + if (!MaybeOpenFile(i, &error)) { // TODO(ttuttle,gavinp): Remove one each of these triplets of histograms. // We can calculate the third as the sum or difference of the other two. - if (create) { - RecordSyncCreateResult(CREATE_ENTRY_PLATFORM_FILE_ERROR, had_index); - UMA_HISTOGRAM_ENUMERATION("SimpleCache.SyncCreatePlatformFileError", - -error, -base::PLATFORM_FILE_ERROR_MAX); - if (had_index) { - UMA_HISTOGRAM_ENUMERATION( - "SimpleCache.SyncCreatePlatformFileError_WithIndex", - -error, -base::PLATFORM_FILE_ERROR_MAX); - } else { - UMA_HISTOGRAM_ENUMERATION( - "SimpleCache.SyncCreatePlatformFileError_WithoutIndex", - -error, -base::PLATFORM_FILE_ERROR_MAX); - } + RecordSyncOpenResult( + cache_type_, OPEN_ENTRY_PLATFORM_FILE_ERROR, had_index); + SIMPLE_CACHE_UMA(ENUMERATION, + "SyncOpenPlatformFileError", cache_type_, + -error, -base::PLATFORM_FILE_ERROR_MAX); + if (had_index) { + SIMPLE_CACHE_UMA(ENUMERATION, + "SyncOpenPlatformFileError_WithIndex", cache_type_, + -error, -base::PLATFORM_FILE_ERROR_MAX); } else { - RecordSyncOpenResult(OPEN_ENTRY_PLATFORM_FILE_ERROR, had_index); - UMA_HISTOGRAM_ENUMERATION("SimpleCache.SyncOpenPlatformFileError", - -error, -base::PLATFORM_FILE_ERROR_MAX); - if (had_index) { - UMA_HISTOGRAM_ENUMERATION( - "SimpleCache.SyncOpenPlatformFileError_WithIndex", - -error, -base::PLATFORM_FILE_ERROR_MAX); - } else { - UMA_HISTOGRAM_ENUMERATION( - "SimpleCache.SyncOpenPlatformFileError_WithoutIndex", - -error, -base::PLATFORM_FILE_ERROR_MAX); - } - } - while (--i >= 0) { - bool ALLOW_UNUSED did_close = ClosePlatformFile(files_[i]); - DLOG_IF(INFO, !did_close) << "Could not close file " - << filename.MaybeAsASCII(); + SIMPLE_CACHE_UMA(ENUMERATION, + "SyncOpenPlatformFileError_WithoutIndex", + cache_type_, + -error, -base::PLATFORM_FILE_ERROR_MAX); } + while (--i >= 0) + CloseFile(i); return false; } } have_open_files_ = true; - if (create) { - out_entry_stat->last_modified = out_entry_stat->last_used = Time::Now(); - for (int i = 0; i < kSimpleEntryFileCount; ++i) - out_entry_stat->data_size[i] = 0; - } else { - for (int i = 0; i < kSimpleEntryFileCount; ++i) { - PlatformFileInfo file_info; - bool success = GetPlatformFileInfo(files_[i], &file_info); - base::Time file_last_modified; - if (!success) { - DLOG(WARNING) << "Could not get platform file info."; - continue; - } - out_entry_stat->last_used = file_info.last_accessed; - if (simple_util::GetMTime(path_, &file_last_modified)) - out_entry_stat->last_modified = file_last_modified; - else - out_entry_stat->last_modified = file_info.last_modified; - - // Keep the file size in |data size_| briefly until the key is initialized - // properly. - out_entry_stat->data_size[i] = file_info.size; + + base::TimeDelta entry_age = base::Time::Now() - base::Time::UnixEpoch(); + for (int i = 0; i < kSimpleEntryFileCount; ++i) { + if (empty_file_omitted_[i]) { + out_entry_stat->set_data_size(i + 1, 0); + continue; + } + + PlatformFileInfo file_info; + bool success = GetPlatformFileInfo(files_[i], &file_info); + base::Time file_last_modified; + if (!success) { + DLOG(WARNING) << "Could not get platform file info."; + continue; } + out_entry_stat->set_last_used(file_info.last_accessed); + if (simple_util::GetMTime(path_, &file_last_modified)) + out_entry_stat->set_last_modified(file_last_modified); + else + out_entry_stat->set_last_modified(file_info.last_modified); + + base::TimeDelta stream_age = + base::Time::Now() - out_entry_stat->last_modified(); + if (stream_age < entry_age) + entry_age = stream_age; + + // Two things prevent from knowing the right values for |data_size|: + // 1) The key is not known, hence its length is unknown. + // 2) Stream 0 and stream 1 are in the same file, and the exact size for + // each will only be known when reading the EOF record for stream 0. + // + // The size for file 0 and 1 is temporarily kept in + // |data_size(1)| and |data_size(2)| respectively. Reading the key in + // InitializeForOpen yields the data size for each file. In the case of + // file hash_1, this is the total size of stream 2, and is assigned to + // data_size(2). In the case of file 0, it is the combined size of stream + // 0, stream 1 and one EOF record. The exact distribution of sizes between + // stream 1 and stream 0 is only determined after reading the EOF record + // for stream 0 in ReadAndValidateStream0. + out_entry_stat->set_data_size(i + 1, file_info.size); } + SIMPLE_CACHE_UMA(CUSTOM_COUNTS, + "SyncOpenEntryAge", cache_type_, + entry_age.InHours(), 1, 1000, 50); return true; } -void SimpleSynchronousEntry::CloseFiles() { +bool SimpleSynchronousEntry::CreateFiles( + bool had_index, + SimpleEntryStat* out_entry_stat) { for (int i = 0; i < kSimpleEntryFileCount; ++i) { - DCHECK_NE(kInvalidPlatformFileValue, files_[i]); - bool did_close = ClosePlatformFile(files_[i]); + PlatformFileError error; + if (!MaybeCreateFile(i, FILE_NOT_REQUIRED, &error)) { + // TODO(ttuttle,gavinp): Remove one each of these triplets of histograms. + // We can calculate the third as the sum or difference of the other two. + RecordSyncCreateResult(CREATE_ENTRY_PLATFORM_FILE_ERROR, had_index); + SIMPLE_CACHE_UMA(ENUMERATION, + "SyncCreatePlatformFileError", cache_type_, + -error, -base::PLATFORM_FILE_ERROR_MAX); + if (had_index) { + SIMPLE_CACHE_UMA(ENUMERATION, + "SyncCreatePlatformFileError_WithIndex", cache_type_, + -error, -base::PLATFORM_FILE_ERROR_MAX); + } else { + SIMPLE_CACHE_UMA(ENUMERATION, + "SyncCreatePlatformFileError_WithoutIndex", + cache_type_, + -error, -base::PLATFORM_FILE_ERROR_MAX); + } + while (--i >= 0) + CloseFile(i); + return false; + } + } + + have_open_files_ = true; + + base::Time creation_time = Time::Now(); + out_entry_stat->set_last_modified(creation_time); + out_entry_stat->set_last_used(creation_time); + for (int i = 0; i < kSimpleEntryStreamCount; ++i) + out_entry_stat->set_data_size(i, 0); + + return true; +} + +void SimpleSynchronousEntry::CloseFile(int index) { + if (empty_file_omitted_[index]) { + empty_file_omitted_[index] = false; + } else { + DCHECK_NE(kInvalidPlatformFileValue, files_[index]); + bool did_close = ClosePlatformFile(files_[index]); DCHECK(did_close); + files_[index] = kInvalidPlatformFileValue; } } -int SimpleSynchronousEntry::InitializeForOpen(bool had_index, - SimpleEntryStat* out_entry_stat) { +void SimpleSynchronousEntry::CloseFiles() { + for (int i = 0; i < kSimpleEntryFileCount; ++i) + CloseFile(i); +} + +int SimpleSynchronousEntry::InitializeForOpen( + bool had_index, + SimpleEntryStat* out_entry_stat, + scoped_refptr<net::GrowableIOBuffer>* stream_0_data, + uint32* out_stream_0_crc32) { DCHECK(!initialized_); - if (!OpenOrCreateFiles(false, had_index, out_entry_stat)) + if (!OpenFiles(had_index, out_entry_stat)) { + DLOG(WARNING) << "Could not open platform files for entry."; return net::ERR_FAILED; - + } for (int i = 0; i < kSimpleEntryFileCount; ++i) { + if (empty_file_omitted_[i]) + continue; + SimpleFileHeader header; int header_read_result = ReadPlatformFile(files_[i], 0, reinterpret_cast<char*>(&header), sizeof(header)); if (header_read_result != sizeof(header)) { DLOG(WARNING) << "Cannot read header from entry."; - RecordSyncOpenResult(OPEN_ENTRY_CANT_READ_HEADER, had_index); + RecordSyncOpenResult(cache_type_, OPEN_ENTRY_CANT_READ_HEADER, had_index); return net::ERR_FAILED; } @@ -553,13 +709,13 @@ int SimpleSynchronousEntry::InitializeForOpen(bool had_index, // should give consideration to not saturating the log with these if that // becomes a problem. DLOG(WARNING) << "Magic number did not match."; - RecordSyncOpenResult(OPEN_ENTRY_BAD_MAGIC_NUMBER, had_index); + RecordSyncOpenResult(cache_type_, OPEN_ENTRY_BAD_MAGIC_NUMBER, had_index); return net::ERR_FAILED; } - if (header.version != kSimpleVersion) { + if (header.version != kSimpleEntryVersionOnDisk) { DLOG(WARNING) << "Unreadable version."; - RecordSyncOpenResult(OPEN_ENTRY_BAD_VERSION, had_index); + RecordSyncOpenResult(cache_type_, OPEN_ENTRY_BAD_VERSION, had_index); return net::ERR_FAILED; } @@ -568,57 +724,93 @@ int SimpleSynchronousEntry::InitializeForOpen(bool had_index, key.get(), header.key_length); if (key_read_result != implicit_cast<int>(header.key_length)) { DLOG(WARNING) << "Cannot read key from entry."; - RecordSyncOpenResult(OPEN_ENTRY_CANT_READ_KEY, had_index); + RecordSyncOpenResult(cache_type_, OPEN_ENTRY_CANT_READ_KEY, had_index); return net::ERR_FAILED; } key_ = std::string(key.get(), header.key_length); - out_entry_stat->data_size[i] = - GetDataSizeFromKeyAndFileSize(key_, out_entry_stat->data_size[i]); - if (out_entry_stat->data_size[i] < 0) { - // This entry can't possibly be valid, as it does not have enough space to - // store a valid SimpleFileEOF record. - return net::ERR_FAILED; + if (i == 0) { + // File size for stream 0 has been stored temporarily in data_size[1]. + int total_data_size = + GetDataSizeFromKeyAndFileSize(key_, out_entry_stat->data_size(1)); + int ret_value_stream_0 = ReadAndValidateStream0( + total_data_size, out_entry_stat, stream_0_data, out_stream_0_crc32); + if (ret_value_stream_0 != net::OK) + return ret_value_stream_0; + } else { + out_entry_stat->set_data_size( + 2, GetDataSizeFromKeyAndFileSize(key_, out_entry_stat->data_size(2))); + if (out_entry_stat->data_size(2) < 0) { + DLOG(WARNING) << "Stream 2 file is too small."; + return net::ERR_FAILED; + } } if (base::Hash(key.get(), header.key_length) != header.key_hash) { DLOG(WARNING) << "Hash mismatch on key."; - RecordSyncOpenResult(OPEN_ENTRY_KEY_HASH_MISMATCH, had_index); + RecordSyncOpenResult( + cache_type_, OPEN_ENTRY_KEY_HASH_MISMATCH, had_index); return net::ERR_FAILED; } } - RecordSyncOpenResult(OPEN_ENTRY_SUCCESS, had_index); + + const int third_stream_file_index = GetFileIndexFromStreamIndex(2); + DCHECK(CanOmitEmptyFile(third_stream_file_index)); + if (!empty_file_omitted_[third_stream_file_index] && + out_entry_stat->data_size(2) == 0) { + DLOG(INFO) << "Removing empty third stream file."; + CloseFile(third_stream_file_index); + DeleteFileForEntryHash(path_, entry_hash_, third_stream_file_index); + empty_file_omitted_[third_stream_file_index] = true; + } + + RecordSyncOpenResult(cache_type_, OPEN_ENTRY_SUCCESS, had_index); initialized_ = true; return net::OK; } +bool SimpleSynchronousEntry::InitializeCreatedFile( + int file_index, + CreateEntryResult* out_result) { + SimpleFileHeader header; + header.initial_magic_number = kSimpleInitialMagicNumber; + header.version = kSimpleEntryVersionOnDisk; + + header.key_length = key_.size(); + header.key_hash = base::Hash(key_); + + int bytes_written = WritePlatformFile( + files_[file_index], 0, reinterpret_cast<char*>(&header), sizeof(header)); + if (bytes_written != sizeof(header)) { + *out_result = CREATE_ENTRY_CANT_WRITE_HEADER; + return false; + } + + bytes_written = WritePlatformFile( + files_[file_index], sizeof(header), key_.data(), key_.size()); + if (bytes_written != implicit_cast<int>(key_.size())) { + *out_result = CREATE_ENTRY_CANT_WRITE_KEY; + return false; + } + + return true; +} + int SimpleSynchronousEntry::InitializeForCreate( bool had_index, SimpleEntryStat* out_entry_stat) { DCHECK(!initialized_); - if (!OpenOrCreateFiles(true, had_index, out_entry_stat)) { + if (!CreateFiles(had_index, out_entry_stat)) { DLOG(WARNING) << "Could not create platform files."; return net::ERR_FILE_EXISTS; } for (int i = 0; i < kSimpleEntryFileCount; ++i) { - SimpleFileHeader header; - header.initial_magic_number = kSimpleInitialMagicNumber; - header.version = kSimpleVersion; - - header.key_length = key_.size(); - header.key_hash = base::Hash(key_); - - if (WritePlatformFile(files_[i], 0, reinterpret_cast<char*>(&header), - sizeof(header)) != sizeof(header)) { - DLOG(WARNING) << "Could not write headers to new cache entry."; - RecordSyncCreateResult(CREATE_ENTRY_CANT_WRITE_HEADER, had_index); - return net::ERR_FAILED; - } + if (empty_file_omitted_[i]) + continue; - if (WritePlatformFile(files_[i], sizeof(header), key_.data(), - key_.size()) != implicit_cast<int>(key_.size())) { - DLOG(WARNING) << "Could not write keys to new cache entry."; - RecordSyncCreateResult(CREATE_ENTRY_CANT_WRITE_KEY, had_index); + CreateEntryResult result; + if (!InitializeCreatedFile(i, &result)) { + RecordSyncCreateResult(result, had_index); return net::ERR_FAILED; } } @@ -627,9 +819,131 @@ int SimpleSynchronousEntry::InitializeForCreate( return net::OK; } +int SimpleSynchronousEntry::ReadAndValidateStream0( + int total_data_size, + SimpleEntryStat* out_entry_stat, + scoped_refptr<net::GrowableIOBuffer>* stream_0_data, + uint32* out_stream_0_crc32) const { + // Temporarily assign all the data size to stream 1 in order to read the + // EOF record for stream 0, which contains the size of stream 0. + out_entry_stat->set_data_size(0, 0); + out_entry_stat->set_data_size(1, total_data_size - sizeof(SimpleFileEOF)); + + bool has_crc32; + uint32 read_crc32; + int stream_0_size; + int ret_value_crc32 = GetEOFRecordData( + 0, *out_entry_stat, &has_crc32, &read_crc32, &stream_0_size); + if (ret_value_crc32 != net::OK) + return ret_value_crc32; + + if (stream_0_size > out_entry_stat->data_size(1)) + return net::ERR_FAILED; + + // These are the real values of data size. + out_entry_stat->set_data_size(0, stream_0_size); + out_entry_stat->set_data_size( + 1, out_entry_stat->data_size(1) - stream_0_size); + + // Put stream 0 data in memory. + *stream_0_data = new net::GrowableIOBuffer(); + (*stream_0_data)->SetCapacity(stream_0_size); + int file_offset = out_entry_stat->GetOffsetInFile(key_, 0, 0); + int bytes_read = ReadPlatformFile( + files_[0], file_offset, (*stream_0_data)->data(), stream_0_size); + if (bytes_read != stream_0_size) + return net::ERR_FAILED; + + // Check the CRC32. + uint32 expected_crc32 = + stream_0_size == 0 + ? crc32(0, Z_NULL, 0) + : crc32(crc32(0, Z_NULL, 0), + reinterpret_cast<const Bytef*>((*stream_0_data)->data()), + stream_0_size); + if (has_crc32 && read_crc32 != expected_crc32) { + DLOG(INFO) << "EOF record had bad crc."; + RecordCheckEOFResult(cache_type_, CHECK_EOF_RESULT_CRC_MISMATCH); + return net::ERR_FAILED; + } + *out_stream_0_crc32 = expected_crc32; + RecordCheckEOFResult(cache_type_, CHECK_EOF_RESULT_SUCCESS); + return net::OK; +} + +int SimpleSynchronousEntry::GetEOFRecordData(int index, + const SimpleEntryStat& entry_stat, + bool* out_has_crc32, + uint32* out_crc32, + int* out_data_size) const { + SimpleFileEOF eof_record; + int file_offset = entry_stat.GetEOFOffsetInFile(key_, index); + int file_index = GetFileIndexFromStreamIndex(index); + if (ReadPlatformFile(files_[file_index], + file_offset, + reinterpret_cast<char*>(&eof_record), + sizeof(eof_record)) != sizeof(eof_record)) { + RecordCheckEOFResult(cache_type_, CHECK_EOF_RESULT_READ_FAILURE); + return net::ERR_CACHE_CHECKSUM_READ_FAILURE; + } + + if (eof_record.final_magic_number != kSimpleFinalMagicNumber) { + RecordCheckEOFResult(cache_type_, CHECK_EOF_RESULT_MAGIC_NUMBER_MISMATCH); + DLOG(INFO) << "EOF record had bad magic number."; + return net::ERR_CACHE_CHECKSUM_READ_FAILURE; + } + + *out_has_crc32 = (eof_record.flags & SimpleFileEOF::FLAG_HAS_CRC32) == + SimpleFileEOF::FLAG_HAS_CRC32; + *out_crc32 = eof_record.data_crc32; + *out_data_size = eof_record.stream_size; + SIMPLE_CACHE_UMA(BOOLEAN, "SyncCheckEOFHasCrc", cache_type_, *out_has_crc32); + return net::OK; +} + void SimpleSynchronousEntry::Doom() const { - // TODO(gavinp): Consider if we should guard against redundant Doom() calls. DeleteFilesForEntryHash(path_, entry_hash_); } +bool SimpleSynchronousEntry::DeleteFileForEntryHash( + const FilePath& path, + const uint64 entry_hash, + const int file_index) { + FilePath to_delete = path.AppendASCII( + GetFilenameFromEntryHashAndFileIndex(entry_hash, file_index)); + return base::DeleteFile(to_delete, false); +} + +bool SimpleSynchronousEntry::DeleteFilesForEntryHash( + const FilePath& path, + const uint64 entry_hash) { + bool result = true; + for (int i = 0; i < kSimpleEntryFileCount; ++i) { + if (!DeleteFileForEntryHash(path, entry_hash, i) && !CanOmitEmptyFile(i)) + result = false; + } + return result; +} + +void SimpleSynchronousEntry::RecordSyncCreateResult(CreateEntryResult result, + bool had_index) { + DCHECK_GT(CREATE_ENTRY_MAX, result); + SIMPLE_CACHE_UMA(ENUMERATION, + "SyncCreateResult", cache_type_, result, CREATE_ENTRY_MAX); + if (had_index) { + SIMPLE_CACHE_UMA(ENUMERATION, + "SyncCreateResult_WithIndex", cache_type_, + result, CREATE_ENTRY_MAX); + } else { + SIMPLE_CACHE_UMA(ENUMERATION, + "SyncCreateResult_WithoutIndex", cache_type_, + result, CREATE_ENTRY_MAX); + } +} + +FilePath SimpleSynchronousEntry::GetFilenameFromFileIndex(int file_index) { + return path_.AppendASCII( + GetFilenameFromEntryHashAndFileIndex(entry_hash_, file_index)); +} + } // namespace disk_cache diff --git a/chromium/net/disk_cache/simple/simple_synchronous_entry.h b/chromium/net/disk_cache/simple/simple_synchronous_entry.h index f591c9a0f93..470e8e20bd8 100644 --- a/chromium/net/disk_cache/simple/simple_synchronous_entry.h +++ b/chromium/net/disk_cache/simple/simple_synchronous_entry.h @@ -11,12 +11,16 @@ #include <vector> #include "base/files/file_path.h" +#include "base/memory/ref_counted.h" #include "base/memory/scoped_ptr.h" #include "base/platform_file.h" #include "base/time/time.h" +#include "net/base/cache_type.h" +#include "net/base/net_export.h" #include "net/disk_cache/simple/simple_entry_format.h" namespace net { +class GrowableIOBuffer; class IOBuffer; } @@ -24,23 +28,48 @@ namespace disk_cache { class SimpleSynchronousEntry; -struct SimpleEntryStat { - SimpleEntryStat(); - SimpleEntryStat(base::Time last_used_p, - base::Time last_modified_p, - const int32 data_size_p[]); +// This class handles the passing of data about the entry between +// SimpleEntryImplementation and SimpleSynchronousEntry and the computation of +// file offsets based on the data size for all streams. +class NET_EXPORT_PRIVATE SimpleEntryStat { + public: + SimpleEntryStat(base::Time last_used, + base::Time last_modified, + const int32 data_size[]); + + int GetOffsetInFile(const std::string& key, + int offset, + int stream_index) const; + int GetEOFOffsetInFile(const std::string& key, int stream_index) const; + int GetLastEOFOffsetInFile(const std::string& key, int file_index) const; + int GetFileSize(const std::string& key, int file_index) const; + + base::Time last_used() const { return last_used_; } + base::Time last_modified() const { return last_modified_; } + void set_last_used(base::Time last_used) { last_used_ = last_used; } + void set_last_modified(base::Time last_modified) { + last_modified_ = last_modified; + } + + int32 data_size(int stream_index) const { return data_size_[stream_index]; } + void set_data_size(int stream_index, int data_size) { + data_size_[stream_index] = data_size; + } - base::Time last_used; - base::Time last_modified; - int32 data_size[kSimpleEntryFileCount]; + private: + base::Time last_used_; + base::Time last_modified_; + int32 data_size_[kSimpleEntryStreamCount]; }; struct SimpleEntryCreationResults { - SimpleEntryCreationResults(SimpleEntryStat entry_stat); + explicit SimpleEntryCreationResults(SimpleEntryStat entry_stat); ~SimpleEntryCreationResults(); SimpleSynchronousEntry* sync_entry; + scoped_refptr<net::GrowableIOBuffer> stream_0_data; SimpleEntryStat entry_stat; + uint32 stream_0_crc32; int result; }; @@ -63,65 +92,81 @@ class SimpleSynchronousEntry { EntryOperationData(int index_p, int offset_p, int buf_len_p, - bool truncate_p); + bool truncate_p, + bool doomed_p); int index; int offset; int buf_len; bool truncate; + bool doomed; }; - static void OpenEntry(const base::FilePath& path, + static void OpenEntry(net::CacheType cache_type, + const base::FilePath& path, uint64 entry_hash, bool had_index, SimpleEntryCreationResults* out_results); - static void CreateEntry(const base::FilePath& path, + static void CreateEntry(net::CacheType cache_type, + const base::FilePath& path, const std::string& key, uint64 entry_hash, bool had_index, SimpleEntryCreationResults* out_results); - // Deletes an entry without first Opening it. Does not check if there is - // already an Entry object in memory holding the open files. Be careful! This - // is meant to be used by the Backend::DoomEntry() call. |callback| will be - // run by |callback_runner|. - static void DoomEntry(const base::FilePath& path, - const std::string& key, - uint64 entry_hash, - int* out_result); + // Deletes an entry from the file system without affecting the state of the + // corresponding instance, if any (allowing operations to continue to be + // executed through that instance). Returns a net error code. + static int DoomEntry(const base::FilePath& path, + uint64 entry_hash); // Like |DoomEntry()| above. Deletes all entries corresponding to the // |key_hashes|. Succeeds only when all entries are deleted. Returns a net // error code. - static int DoomEntrySet(scoped_ptr<std::vector<uint64> > key_hashes, + static int DoomEntrySet(const std::vector<uint64>* key_hashes, const base::FilePath& path); // N.B. ReadData(), WriteData(), CheckEOFRecord() and Close() may block on IO. void ReadData(const EntryOperationData& in_entry_op, net::IOBuffer* out_buf, uint32* out_crc32, - base::Time* out_last_used, + SimpleEntryStat* entry_stat, int* out_result) const; void WriteData(const EntryOperationData& in_entry_op, net::IOBuffer* in_buf, SimpleEntryStat* out_entry_stat, - int* out_result) const; + int* out_result); void CheckEOFRecord(int index, - int data_size, + const SimpleEntryStat& entry_stat, uint32 expected_crc32, int* out_result) const; // Close all streams, and add write EOF records to streams indicated by the // CRCRecord entries in |crc32s_to_write|. void Close(const SimpleEntryStat& entry_stat, - scoped_ptr<std::vector<CRCRecord> > crc32s_to_write); + scoped_ptr<std::vector<CRCRecord> > crc32s_to_write, + net::GrowableIOBuffer* stream_0_data); const base::FilePath& path() const { return path_; } std::string key() const { return key_; } private: + enum CreateEntryResult { + CREATE_ENTRY_SUCCESS = 0, + CREATE_ENTRY_PLATFORM_FILE_ERROR = 1, + CREATE_ENTRY_CANT_WRITE_HEADER = 2, + CREATE_ENTRY_CANT_WRITE_KEY = 3, + CREATE_ENTRY_MAX = 4, + }; + + enum FileRequired { + FILE_NOT_REQUIRED, + FILE_REQUIRED + }; + SimpleSynchronousEntry( + net::CacheType cache_type, const base::FilePath& path, const std::string& key, uint64 entry_hash); @@ -130,15 +175,36 @@ class SimpleSynchronousEntry { // called. ~SimpleSynchronousEntry(); - bool OpenOrCreateFiles(bool create, - bool had_index, - SimpleEntryStat* out_entry_stat); + // Tries to open one of the cache entry files. Succeeds if the open succeeds + // or if the file was not found and is allowed to be omitted if the + // corresponding stream is empty. + bool MaybeOpenFile(int file_index, + base::PlatformFileError* out_error); + // Creates one of the cache entry files if necessary. If the file is allowed + // to be omitted if the corresponding stream is empty, and if |file_required| + // is FILE_NOT_REQUIRED, then the file is not created; otherwise, it is. + bool MaybeCreateFile(int file_index, + FileRequired file_required, + base::PlatformFileError* out_error); + bool OpenFiles(bool had_index, + SimpleEntryStat* out_entry_stat); + bool CreateFiles(bool had_index, + SimpleEntryStat* out_entry_stat); + void CloseFile(int index); void CloseFiles(); // Returns a net error, i.e. net::OK on success. |had_index| is passed // from the main entry for metrics purposes, and is true if the index was // initialized when the open operation began. - int InitializeForOpen(bool had_index, SimpleEntryStat* out_entry_stat); + int InitializeForOpen(bool had_index, + SimpleEntryStat* out_entry_stat, + scoped_refptr<net::GrowableIOBuffer>* stream_0_data, + uint32* out_stream_0_crc32); + + // Writes the header and key to a newly-created stream file. |index| is the + // index of the stream. Returns true on success; returns false and sets + // |*out_result| on failure. + bool InitializeCreatedFile(int index, CreateEntryResult* out_result); // Returns a net error, including net::OK on success and net::FILE_EXISTS // when the entry already exists. |had_index| is passed from the main entry @@ -146,11 +212,32 @@ class SimpleSynchronousEntry { // create operation began. int InitializeForCreate(bool had_index, SimpleEntryStat* out_entry_stat); + // Allocates and fills a buffer with stream 0 data in |stream_0_data|, then + // checks its crc32. + int ReadAndValidateStream0( + int total_data_size, + SimpleEntryStat* out_entry_stat, + scoped_refptr<net::GrowableIOBuffer>* stream_0_data, + uint32* out_stream_0_crc32) const; + + int GetEOFRecordData(int index, + const SimpleEntryStat& entry_stat, + bool* out_has_crc32, + uint32* out_crc32, + int* out_data_size) const; void Doom() const; + static bool DeleteFileForEntryHash(const base::FilePath& path, + uint64 entry_hash, + int file_index); static bool DeleteFilesForEntryHash(const base::FilePath& path, uint64 entry_hash); + void RecordSyncCreateResult(CreateEntryResult result, bool had_index); + + base::FilePath GetFilenameFromFileIndex(int file_index); + + const net::CacheType cache_type_; const base::FilePath path_; const uint64 entry_hash_; std::string key_; @@ -159,6 +246,10 @@ class SimpleSynchronousEntry { bool initialized_; base::PlatformFile files_[kSimpleEntryFileCount]; + + // True if the corresponding stream is empty and therefore no on-disk file + // was created to store it. + bool empty_file_omitted_[kSimpleEntryFileCount]; }; } // namespace disk_cache diff --git a/chromium/net/disk_cache/simple/simple_test_util.cc b/chromium/net/disk_cache/simple/simple_test_util.cc index 483cbec1cdd..9a275588230 100644 --- a/chromium/net/disk_cache/simple/simple_test_util.cc +++ b/chromium/net/disk_cache/simple/simple_test_util.cc @@ -8,13 +8,12 @@ #include "net/disk_cache/simple/simple_util.h" namespace disk_cache { - namespace simple_util { bool CreateCorruptFileForTests(const std::string& key, const base::FilePath& cache_path) { base::FilePath entry_file_path = cache_path.AppendASCII( - disk_cache::simple_util::GetFilenameFromKeyAndIndex(key, 0)); + disk_cache::simple_util::GetFilenameFromKeyAndFileIndex(key, 0)); int flags = base::PLATFORM_FILE_CREATE_ALWAYS | base::PLATFORM_FILE_WRITE; base::PlatformFile entry_file = base::CreatePlatformFile(entry_file_path, flags, NULL, NULL); @@ -30,5 +29,4 @@ bool CreateCorruptFileForTests(const std::string& key, } } // namespace simple_backend - } // namespace disk_cache diff --git a/chromium/net/disk_cache/simple/simple_test_util.h b/chromium/net/disk_cache/simple/simple_test_util.h index 98c140b1f14..82eebbec314 100644 --- a/chromium/net/disk_cache/simple/simple_test_util.h +++ b/chromium/net/disk_cache/simple/simple_test_util.h @@ -8,22 +8,41 @@ #include <string> #include "base/basictypes.h" -#include "net/base/net_export.h" +#include "base/callback.h" namespace base { class FilePath; } namespace disk_cache { - namespace simple_util { +// Immutable array with compile-time bound-checking. +template <typename T, size_t Size> +class ImmutableArray { + public: + static const size_t size = Size; + + ImmutableArray(const base::Callback<T (size_t index)>& initializer) { + for (size_t i = 0; i < size; ++i) + data_[i] = initializer.Run(i); + } + + template <size_t Index> + const T& at() const { + COMPILE_ASSERT(Index < size, array_out_of_bounds); + return data_[Index]; + } + + private: + T data_[size]; +}; + // Creates a corrupt file to be used in tests. bool CreateCorruptFileForTests(const std::string& key, const base::FilePath& cache_path); } // namespace simple_backend - } // namespace disk_cache #endif // NET_DISK_CACHE_SIMPLE_SIMPLE_TEST_UTIL_H_ diff --git a/chromium/net/disk_cache/simple/simple_util.cc b/chromium/net/disk_cache/simple/simple_util.cc index e9ec067635a..4291b1f7773 100644 --- a/chromium/net/disk_cache/simple/simple_util.cc +++ b/chromium/net/disk_cache/simple/simple_util.cc @@ -21,6 +21,25 @@ namespace { // Size of the uint64 hash_key number in Hex format in a string. const size_t kEntryHashKeyAsHexStringSize = 2 * sizeof(uint64); +// TODO(clamy, gavinp): this should go in base +bool GetNanoSecsFromStat(const struct stat& st, + time_t* out_sec, + long* out_nsec) { +#if defined(OS_ANDROID) + *out_sec = st.st_mtime; + *out_nsec = st.st_mtime_nsec; +#elif defined(OS_LINUX) + *out_sec = st.st_mtim.tv_sec; + *out_nsec = st.st_mtim.tv_nsec; +#elif defined(OS_MACOSX) || defined(OS_IOS) || defined(OS_BSD) + *out_sec = st.st_mtimespec.tv_sec; + *out_nsec = st.st_mtimespec.tv_nsec; +#else + return false; +#endif + return true; +} + } // namespace namespace disk_cache { @@ -58,13 +77,15 @@ uint64 GetEntryHashKey(const std::string& key) { return u.key_hash; } -std::string GetFilenameFromEntryHashAndIndex(uint64 entry_hash, - int index) { - return base::StringPrintf("%016" PRIx64 "_%1d", entry_hash, index); +std::string GetFilenameFromEntryHashAndFileIndex(uint64 entry_hash, + int file_index) { + return base::StringPrintf("%016" PRIx64 "_%1d", entry_hash, file_index); } -std::string GetFilenameFromKeyAndIndex(const std::string& key, int index) { - return GetEntryHashKeyAsHexString(key) + base::StringPrintf("_%1d", index); +std::string GetFilenameFromKeyAndFileIndex(const std::string& key, + int file_index) { + return GetEntryHashKeyAsHexString(key) + + base::StringPrintf("_%1d", file_index); } int32 GetDataSizeFromKeyAndFileSize(const std::string& key, int64 file_size) { @@ -79,15 +100,27 @@ int64 GetFileSizeFromKeyAndDataSize(const std::string& key, int32 data_size) { sizeof(SimpleFileEOF); } -int64 GetFileOffsetFromKeyAndDataOffset(const std::string& key, - int data_offset) { - const int64 headers_size = sizeof(disk_cache::SimpleFileHeader) + key.size(); - return headers_size + data_offset; +int GetFileIndexFromStreamIndex(int stream_index) { + return (stream_index == 2) ? 1 : 0; } // TODO(clamy, gavinp): this should go in base bool GetMTime(const base::FilePath& path, base::Time* out_mtime) { DCHECK(out_mtime); +#if defined(OS_POSIX) + base::ThreadRestrictions::AssertIOAllowed(); + struct stat file_stat; + if (stat(path.value().c_str(), &file_stat) != 0) + return false; + time_t sec; + long nsec; + if (GetNanoSecsFromStat(file_stat, &sec, &nsec)) { + int64 usec = (nsec / base::Time::kNanosecondsPerMicrosecond); + *out_mtime = base::Time::FromTimeT(sec) + + base::TimeDelta::FromMicroseconds(usec); + return true; + } +#endif base::PlatformFileInfo file_info; if (!file_util::GetFileInfo(path, &file_info)) return false; diff --git a/chromium/net/disk_cache/simple/simple_util.h b/chromium/net/disk_cache/simple/simple_util.h index 3bb80b95402..60a237ecd86 100644 --- a/chromium/net/disk_cache/simple/simple_util.h +++ b/chromium/net/disk_cache/simple/simple_util.h @@ -40,12 +40,13 @@ NET_EXPORT_PRIVATE bool GetEntryHashKeyFromHexString( // Given a |key| for a (potential) entry in the simple backend and the |index| // of a stream on that entry, returns the filename in which that stream would be // stored. -NET_EXPORT_PRIVATE std::string GetFilenameFromKeyAndIndex( +NET_EXPORT_PRIVATE std::string GetFilenameFromKeyAndFileIndex( const std::string& key, - int index); + int file_index); // Same as |GetFilenameFromKeyAndIndex| above, but using a hex string. -std::string GetFilenameFromEntryHashAndIndex(uint64 entry_hash, int index); +std::string GetFilenameFromEntryHashAndFileIndex(uint64 entry_hash, + int file_index); // Given the size of a file holding a stream in the simple backend and the key // to an entry, returns the number of bytes in the stream. @@ -57,14 +58,12 @@ NET_EXPORT_PRIVATE int32 GetDataSizeFromKeyAndFileSize(const std::string& key, NET_EXPORT_PRIVATE int64 GetFileSizeFromKeyAndDataSize(const std::string& key, int32 data_size); -// Given the key to an entry, and an offset into a stream on that entry, returns -// the file offset corresponding to |data_offset|. -NET_EXPORT_PRIVATE int64 GetFileOffsetFromKeyAndDataOffset( - const std::string& key, - int data_offset); +// Given the stream index, returns the number of the file the stream is stored +// in. +NET_EXPORT_PRIVATE int GetFileIndexFromStreamIndex(int stream_index); -// Fills |out_time| with the time the file last modified time. -// TODO(gavinp): Remove this function. +// Fills |out_time| with the time the file last modified time. Unlike the +// functions in platform_file.h, the time resolution is milliseconds. NET_EXPORT_PRIVATE bool GetMTime(const base::FilePath& path, base::Time* out_mtime); } // namespace simple_backend diff --git a/chromium/net/disk_cache/simple/simple_util_unittest.cc b/chromium/net/disk_cache/simple/simple_util_unittest.cc index 35388e9f172..ea039fc9ee6 100644 --- a/chromium/net/disk_cache/simple/simple_util_unittest.cc +++ b/chromium/net/disk_cache/simple/simple_util_unittest.cc @@ -30,7 +30,8 @@ TEST_F(SimpleUtilTest, ConvertEntryHashKeyToHexString) { TEST_F(SimpleUtilTest, GetEntryHashKey) { EXPECT_EQ("7ac408c1dff9c84b", GetEntryHashKeyAsHexString("http://www.amazon.com/")); - EXPECT_EQ(GG_UINT64_C(0x7ac408c1dff9c84b), GetEntryHashKey("http://www.amazon.com/")); + EXPECT_EQ(GG_UINT64_C(0x7ac408c1dff9c84b), + GetEntryHashKey("http://www.amazon.com/")); EXPECT_EQ("9fe947998c2ccf47", GetEntryHashKeyAsHexString("www.amazon.com")); diff --git a/chromium/net/disk_cache/simple/simple_version_upgrade.cc b/chromium/net/disk_cache/simple/simple_version_upgrade.cc new file mode 100644 index 00000000000..dfc6ef4394b --- /dev/null +++ b/chromium/net/disk_cache/simple/simple_version_upgrade.cc @@ -0,0 +1,203 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/disk_cache/simple/simple_version_upgrade.h" + +#include <cstring> + +#include "base/file_util.h" +#include "base/files/file_path.h" +#include "base/files/memory_mapped_file.h" +#include "base/logging.h" +#include "base/pickle.h" +#include "net/disk_cache/simple/simple_backend_version.h" +#include "net/disk_cache/simple/simple_entry_format_history.h" +#include "third_party/zlib/zlib.h" + +namespace { + +// It is not possible to upgrade cache structures on disk that are of version +// below this, the entire cache should be dropped for them. +const uint32 kMinVersionAbleToUpgrade = 5; + +const char kFakeIndexFileName[] = "index"; +const char kIndexFileName[] = "the-real-index"; + +void LogMessageFailedUpgradeFromVersion(int version) { + LOG(ERROR) << "Failed to upgrade Simple Cache from version: " << version; +} + +bool WriteFakeIndexFile(const base::FilePath& file_name) { + base::PlatformFileError error; + base::PlatformFile file = base::CreatePlatformFile( + file_name, + base::PLATFORM_FILE_CREATE | base::PLATFORM_FILE_WRITE, + NULL, + &error); + disk_cache::FakeIndexData file_contents; + file_contents.initial_magic_number = + disk_cache::simplecache_v5::kSimpleInitialMagicNumber; + file_contents.version = disk_cache::kSimpleVersion; + int bytes_written = base::WritePlatformFile( + file, 0, reinterpret_cast<char*>(&file_contents), sizeof(file_contents)); + if (!base::ClosePlatformFile(file) || + bytes_written != sizeof(file_contents)) { + LOG(ERROR) << "Failed to write fake index file: " + << file_name.LossyDisplayName(); + return false; + } + return true; +} + +} // namespace + +namespace disk_cache { + +FakeIndexData::FakeIndexData() { + // Make hashing repeatable: leave no padding bytes untouched. + std::memset(this, 0, sizeof(*this)); +} + +// Migrates the cache directory from version 4 to version 5. +// Returns true iff it succeeds. +// +// The V5 and V6 caches differ in the name of the index file (it moved to a +// subdirectory) and in the file format (directory last-modified time observed +// by the index writer has gotten appended to the pickled format). +// +// To keep complexity small this specific upgrade code *deletes* the old index +// file. The directory for the new index file has to be created lazily anyway, +// so it is not done in the upgrader. +// +// Below is the detailed description of index file format differences. It is for +// reference purposes. This documentation would be useful to move closer to the +// next index upgrader when the latter gets introduced. +// +// Path: +// V5: $cachedir/the-real-index +// V6: $cachedir/index-dir/the-real-index +// +// Pickled file format: +// Both formats extend Pickle::Header by 32bit value of the CRC-32 of the +// pickled data. +// <v5-index> ::= <v5-index-metadata> <entry-info>* +// <v5-index-metadata> ::= UInt64(kSimpleIndexMagicNumber) +// UInt32(4) +// UInt64(<number-of-entries>) +// UInt64(<cache-size-in-bytes>) +// <entry-info> ::= UInt64(<hash-of-the-key>) +// Int64(<entry-last-used-time>) +// UInt64(<entry-size-in-bytes>) +// <v6-index> ::= <v6-index-metadata> +// <entry-info>* +// Int64(<cache-dir-mtime>) +// <v6-index-metadata> ::= UInt64(kSimpleIndexMagicNumber) +// UInt32(5) +// UInt64(<number-of-entries>) +// UInt64(<cache-size-in-bytes>) +// Where: +// <entry-size-in-bytes> is equal the sum of all file sizes of the entry. +// <cache-dir-mtime> is the last modification time with nanosecond precision +// of the directory, where all files for entries are stored. +// <hash-of-the-key> represent the first 64 bits of a SHA-1 of the key. +bool UpgradeIndexV5V6(const base::FilePath& cache_directory) { + const base::FilePath old_index_file = + cache_directory.AppendASCII(kIndexFileName); + if (!base::DeleteFile(old_index_file, /* recursive = */ false)) + return false; + return true; +} + +// Some points about the Upgrade process are still not clear: +// 1. if the upgrade path requires dropping cache it would be faster to just +// return an initialization error here and proceed with asynchronous cache +// cleanup in CacheCreator. Should this hack be considered valid? Some smart +// tests may fail. +// 2. Because Android process management allows for killing a process at any +// time, the upgrade process may need to deal with a partially completed +// previous upgrade. For example, while upgrading A -> A + 2 we are the +// process gets killed and some parts are remaining at version A + 1. There +// are currently no generic mechanisms to resolve this situation, co the +// upgrade codes need to ensure they can continue after being stopped in the +// middle. It also means that the "fake index" must be flushed in between the +// upgrade steps. Atomicity of this is an interesting research topic. The +// intermediate fake index flushing must be added as soon as we add more +// upgrade steps. +bool UpgradeSimpleCacheOnDisk(const base::FilePath& path) { + // There is a convention among disk cache backends: looking at the magic in + // the file "index" it should be sufficient to determine if the cache belongs + // to the currently running backend. The Simple Backend stores its index in + // the file "the-real-index" (see simple_index_file.cc) and the file "index" + // only signifies presence of the implementation's magic and version. There + // are two reasons for that: + // 1. Absence of the index is itself not a fatal error in the Simple Backend + // 2. The Simple Backend has pickled file format for the index making it hacky + // to have the magic in the right place. + const base::FilePath fake_index = path.AppendASCII(kFakeIndexFileName); + base::PlatformFileError error; + base::PlatformFile fake_index_file = base::CreatePlatformFile( + fake_index, + base::PLATFORM_FILE_OPEN | base::PLATFORM_FILE_READ, + NULL, + &error); + if (error == base::PLATFORM_FILE_ERROR_NOT_FOUND) { + return WriteFakeIndexFile(fake_index); + } else if (error != base::PLATFORM_FILE_OK) { + return false; + } + FakeIndexData file_header; + int bytes_read = base::ReadPlatformFile(fake_index_file, + 0, + reinterpret_cast<char*>(&file_header), + sizeof(file_header)); + if (!base::ClosePlatformFile(fake_index_file) || + bytes_read != sizeof(file_header) || + file_header.initial_magic_number != + disk_cache::simplecache_v5::kSimpleInitialMagicNumber) { + LOG(ERROR) << "File structure does not match the disk cache backend."; + return false; + } + + uint32 version_from = file_header.version; + if (version_from < kMinVersionAbleToUpgrade || + version_from > kSimpleVersion) { + LOG(ERROR) << "Inconsistent cache version."; + return false; + } + bool upgrade_needed = (version_from != kSimpleVersion); + if (version_from == kMinVersionAbleToUpgrade) { + // Upgrade only the index for V4 -> V5 move. + if (!UpgradeIndexV5V6(path)) { + LogMessageFailedUpgradeFromVersion(file_header.version); + return false; + } + version_from++; + } + if (version_from == kSimpleVersion) { + if (!upgrade_needed) { + return true; + } else { + const base::FilePath temp_fake_index = path.AppendASCII("upgrade-index"); + if (!WriteFakeIndexFile(temp_fake_index)) { + base::DeleteFile(temp_fake_index, /* recursive = */ false); + LOG(ERROR) << "Failed to write a new fake index."; + LogMessageFailedUpgradeFromVersion(file_header.version); + return false; + } + if (!base::ReplaceFile(temp_fake_index, fake_index, NULL)) { + LOG(ERROR) << "Failed to replace the fake index."; + LogMessageFailedUpgradeFromVersion(file_header.version); + return false; + } + return true; + } + } + // Verify during the test stage that the upgraders are implemented for all + // versions. The release build would cause backend initialization failure + // which would then later lead to removing all files known to the backend. + DCHECK_EQ(kSimpleVersion, version_from); + return false; +} + +} // namespace disk_cache diff --git a/chromium/net/disk_cache/simple/simple_version_upgrade.h b/chromium/net/disk_cache/simple/simple_version_upgrade.h new file mode 100644 index 00000000000..352379b997b --- /dev/null +++ b/chromium/net/disk_cache/simple/simple_version_upgrade.h @@ -0,0 +1,50 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_DISK_CACHE_SIMPLE_SIMPLE_VERSION_UPGRADE_H_ +#define NET_DISK_CACHE_SIMPLE_SIMPLE_VERSION_UPGRADE_H_ + +// Defines functionality to upgrade the file structure of the Simple Cache +// Backend on disk. Assumes no backend operations are running simultaneously. +// Hence must be run at cache initialization step. + +#include "base/basictypes.h" +#include "net/base/net_export.h" + +namespace base { +class FilePath; +} + +namespace disk_cache { + +// Performs all necessary disk IO to upgrade the cache structure if it is +// needed. +// +// Returns true iff no errors were found during consistency checks and all +// necessary transitions succeeded. If this function fails, there is nothing +// left to do other than dropping the whole cache directory. +NET_EXPORT_PRIVATE bool UpgradeSimpleCacheOnDisk(const base::FilePath& path); + +// The format for the fake index has mistakenly acquired two extra fields that +// do not contain any useful data. Since they were equal to zero, they are now +// mandatated to be zero. +struct NET_EXPORT_PRIVATE FakeIndexData { + FakeIndexData(); + + // Must be equal to simplecache_v4::kSimpleInitialMagicNumber. + uint64 initial_magic_number; + + // Must be equal kSimpleVersion when the cache backend is instantiated. + uint32 version; + + uint32 unused_must_be_zero1; + uint32 unused_must_be_zero2; +}; + +// Exposed for testing. +NET_EXPORT_PRIVATE bool UpgradeIndexV5V6(const base::FilePath& cache_directory); + +} // namespace disk_cache + +#endif // NET_DISK_CACHE_SIMPLE_SIMPLE_VERSION_UPGRADE_H_ diff --git a/chromium/net/disk_cache/simple/simple_version_upgrade_unittest.cc b/chromium/net/disk_cache/simple/simple_version_upgrade_unittest.cc new file mode 100644 index 00000000000..c9d42f10660 --- /dev/null +++ b/chromium/net/disk_cache/simple/simple_version_upgrade_unittest.cc @@ -0,0 +1,146 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/disk_cache/simple/simple_version_upgrade.h" + +#include "base/basictypes.h" +#include "base/file_util.h" +#include "base/files/file_path.h" +#include "base/files/scoped_temp_dir.h" +#include "base/format_macros.h" +#include "base/strings/stringprintf.h" +#include "net/base/net_errors.h" +#include "net/disk_cache/simple/simple_backend_version.h" +#include "net/disk_cache/simple/simple_entry_format_history.h" +#include "testing/gtest/include/gtest/gtest.h" + +// The migration process relies on ability to rename newly created files, which +// could be problematic on Windows XP. +#if defined(OS_POSIX) + +namespace { + +// Same as |disk_cache::kSimpleInitialMagicNumber|. +const uint64 kSimpleInitialMagicNumber = GG_UINT64_C(0xfcfb6d1ba7725c30); + +// The "fake index" file that cache backends use to distinguish whether the +// cache belongs to one backend or another. +const char kFakeIndexFileName[] = "index"; + +// Same as |SimpleIndexFile::kIndexFileName|. +const char kIndexFileName[] = "the-real-index"; + +// Same as |SimpleIndexFile::kIndexDirectory|. +const char kIndexDirectory[] = "index-dir"; + +// Same as |SimpleIndexFile::kTempIndexFileName|. +const char kTempIndexFileName[] = "temp-index"; + +bool WriteFakeIndexFileV5(const base::FilePath& cache_path) { + disk_cache::FakeIndexData data; + data.version = 5; + data.initial_magic_number = kSimpleInitialMagicNumber; + data.unused_must_be_zero1 = 0; + data.unused_must_be_zero2 = 0; + const base::FilePath file_name = cache_path.AppendASCII("index"); + return sizeof(data) == + file_util::WriteFile( + file_name, reinterpret_cast<const char*>(&data), sizeof(data)); +} + +TEST(SimpleVersionUpgradeTest, FailsToMigrateBackwards) { + base::ScopedTempDir cache_dir; + ASSERT_TRUE(cache_dir.CreateUniqueTempDir()); + const base::FilePath cache_path = cache_dir.path(); + + disk_cache::FakeIndexData data; + data.version = 100500; + data.initial_magic_number = kSimpleInitialMagicNumber; + data.unused_must_be_zero1 = 0; + data.unused_must_be_zero2 = 0; + const base::FilePath file_name = cache_path.AppendASCII(kFakeIndexFileName); + ASSERT_EQ(implicit_cast<int>(sizeof(data)), + file_util::WriteFile( + file_name, reinterpret_cast<const char*>(&data), sizeof(data))); + EXPECT_FALSE(disk_cache::UpgradeSimpleCacheOnDisk(cache_dir.path())); +} + +TEST(SimpleVersionUpgradeTest, FakeIndexVersionGetsUpdated) { + base::ScopedTempDir cache_dir; + ASSERT_TRUE(cache_dir.CreateUniqueTempDir()); + const base::FilePath cache_path = cache_dir.path(); + + WriteFakeIndexFileV5(cache_path); + const std::string file_contents("incorrectly serialized data"); + const base::FilePath index_file = cache_path.AppendASCII(kIndexFileName); + ASSERT_EQ(implicit_cast<int>(file_contents.size()), + file_util::WriteFile( + index_file, file_contents.data(), file_contents.size())); + + // Upgrade. + ASSERT_TRUE(disk_cache::UpgradeSimpleCacheOnDisk(cache_path)); + + // Check that the version in the fake index file is updated. + std::string new_fake_index_contents; + ASSERT_TRUE(base::ReadFileToString(cache_path.AppendASCII(kFakeIndexFileName), + &new_fake_index_contents)); + const disk_cache::FakeIndexData* fake_index_header; + EXPECT_EQ(sizeof(*fake_index_header), new_fake_index_contents.size()); + fake_index_header = reinterpret_cast<const disk_cache::FakeIndexData*>( + new_fake_index_contents.data()); + EXPECT_EQ(disk_cache::kSimpleVersion, fake_index_header->version); + EXPECT_EQ(kSimpleInitialMagicNumber, fake_index_header->initial_magic_number); +} + +TEST(SimpleVersionUpgradeTest, UpgradeV5V6IndexMustDisappear) { + base::ScopedTempDir cache_dir; + ASSERT_TRUE(cache_dir.CreateUniqueTempDir()); + const base::FilePath cache_path = cache_dir.path(); + + WriteFakeIndexFileV5(cache_path); + const std::string file_contents("incorrectly serialized data"); + const base::FilePath index_file = cache_path.AppendASCII(kIndexFileName); + ASSERT_EQ(implicit_cast<int>(file_contents.size()), + file_util::WriteFile( + index_file, file_contents.data(), file_contents.size())); + + // Create a few entry-like files. + const uint64 kEntries = 5; + for (uint64 entry_hash = 0; entry_hash < kEntries; ++entry_hash) { + for (int index = 0; index < 3; ++index) { + std::string file_name = + base::StringPrintf("%016" PRIx64 "_%1d", entry_hash, index); + std::string entry_contents = + file_contents + + base::StringPrintf(" %" PRIx64, implicit_cast<uint64>(entry_hash)); + ASSERT_EQ(implicit_cast<int>(entry_contents.size()), + file_util::WriteFile(cache_path.AppendASCII(file_name), + entry_contents.data(), + entry_contents.size())); + } + } + + // Upgrade. + ASSERT_TRUE(disk_cache::UpgradeIndexV5V6(cache_path)); + + // Check that the old index disappeared but the files remain unchanged. + EXPECT_FALSE(base::PathExists(index_file)); + for (uint64 entry_hash = 0; entry_hash < kEntries; ++entry_hash) { + for (int index = 0; index < 3; ++index) { + std::string file_name = + base::StringPrintf("%016" PRIx64 "_%1d", entry_hash, index); + std::string expected_contents = + file_contents + + base::StringPrintf(" %" PRIx64, implicit_cast<uint64>(entry_hash)); + std::string real_contents; + EXPECT_TRUE(base::ReadFileToString(cache_path.AppendASCII(file_name), + &real_contents)); + EXPECT_EQ(expected_contents, real_contents); + } + } +} + +} // namespace + +#endif // defined(OS_POSIX) diff --git a/chromium/net/disk_cache/stats.cc b/chromium/net/disk_cache/stats.cc index 33d9d1c534a..b05ecc6ec0b 100644 --- a/chromium/net/disk_cache/stats.cc +++ b/chromium/net/disk_cache/stats.cc @@ -92,6 +92,9 @@ Stats::Stats() : size_histogram_(NULL) { } Stats::~Stats() { + if (size_histogram_) { + size_histogram_->Disable(); + } } bool Stats::Init(void* data, int num_bytes, Addr address) { diff --git a/chromium/net/disk_cache/stats_histogram.cc b/chromium/net/disk_cache/stats_histogram.cc index 33adfeaae49..2a675501955 100644 --- a/chromium/net/disk_cache/stats_histogram.cc +++ b/chromium/net/disk_cache/stats_histogram.cc @@ -70,9 +70,14 @@ StatsHistogram* StatsHistogram::FactoryGet(const std::string& name, return return_histogram; } +void StatsHistogram::Disable() { + stats_ = NULL; +} + scoped_ptr<HistogramSamples> StatsHistogram::SnapshotSamples() const { scoped_ptr<SampleVector> samples(new SampleVector(bucket_ranges())); - stats_->Snapshot(samples.get()); + if (stats_) + stats_->Snapshot(samples.get()); // Only report UMA data once. StatsHistogram* mutable_me = const_cast<StatsHistogram*>(this); diff --git a/chromium/net/disk_cache/stats_histogram.h b/chromium/net/disk_cache/stats_histogram.h index 279a1c3c71c..2e481f52e42 100644 --- a/chromium/net/disk_cache/stats_histogram.h +++ b/chromium/net/disk_cache/stats_histogram.h @@ -41,6 +41,9 @@ class StatsHistogram : public base::Histogram { static StatsHistogram* FactoryGet(const std::string& name, const Stats* stats); + // Disables this histogram when the underlying Stats go away. + void Disable(); + virtual scoped_ptr<base::HistogramSamples> SnapshotSamples() const OVERRIDE; virtual int FindCorruption( const base::HistogramSamples& samples) const OVERRIDE; diff --git a/chromium/net/disk_cache/storage_block-inl.h b/chromium/net/disk_cache/storage_block-inl.h index 098cd74afa7..b9061861d8f 100644 --- a/chromium/net/disk_cache/storage_block-inl.h +++ b/chromium/net/disk_cache/storage_block-inl.h @@ -143,6 +143,36 @@ template<typename T> bool StorageBlock<T>::Store() { return false; } +template<typename T> bool StorageBlock<T>::Load(FileIOCallback* callback, + bool* completed) { + if (file_) { + if (!data_) + AllocateData(); + + if (file_->Load(this, callback, completed)) { + modified_ = false; + return true; + } + } + LOG(WARNING) << "Failed data load."; + Trace("Failed data load."); + return false; +} + +template<typename T> bool StorageBlock<T>::Store(FileIOCallback* callback, + bool* completed) { + if (file_ && data_) { + data_->self_hash = CalculateHash(); + if (file_->Store(this, callback, completed)) { + modified_ = false; + return true; + } + } + LOG(ERROR) << "Failed data store."; + Trace("Failed data store."); + return false; +} + template<typename T> void StorageBlock<T>::AllocateData() { DCHECK(!data_); if (!extended_) { diff --git a/chromium/net/disk_cache/storage_block.h b/chromium/net/disk_cache/storage_block.h index 65c67fc4b44..f7690ed5c4d 100644 --- a/chromium/net/disk_cache/storage_block.h +++ b/chromium/net/disk_cache/storage_block.h @@ -74,6 +74,8 @@ class StorageBlock : public FileBlock { // Loads and store the data. bool Load(); bool Store(); + bool Load(FileIOCallback* callback, bool* completed); + bool Store(FileIOCallback* callback, bool* completed); private: void AllocateData(); diff --git a/chromium/net/disk_cache/v3/block_bitmaps.cc b/chromium/net/disk_cache/v3/block_bitmaps.cc index 0d0317b39dc..b68ecdd14f9 100644 --- a/chromium/net/disk_cache/v3/block_bitmaps.cc +++ b/chromium/net/disk_cache/v3/block_bitmaps.cc @@ -2,143 +2,73 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#include "net/disk_cache/block_files.h" +#include "net/disk_cache/v3/block_bitmaps.h" -#include "base/atomicops.h" -#include "base/file_util.h" #include "base/metrics/histogram.h" -#include "base/strings/string_util.h" -#include "base/strings/stringprintf.h" -#include "base/threading/thread_checker.h" #include "base/time/time.h" -#include "net/disk_cache/cache_util.h" -#include "net/disk_cache/file_lock.h" +#include "net/disk_cache/disk_format_base.h" #include "net/disk_cache/trace.h" using base::TimeTicks; namespace disk_cache { -BlockFiles::BlockFiles(const base::FilePath& path) - : init_(false), zero_buffer_(NULL), path_(path) { +BlockBitmaps::BlockBitmaps() { } -BlockFiles::~BlockFiles() { - if (zero_buffer_) - delete[] zero_buffer_; - CloseFiles(); +BlockBitmaps::~BlockBitmaps() { } -bool BlockFiles::Init(bool create_files) { - DCHECK(!init_); - if (init_) - return false; - - thread_checker_.reset(new base::ThreadChecker); - - block_files_.resize(kFirstAdditionalBlockFile); - for (int i = 0; i < kFirstAdditionalBlockFile; i++) { - if (create_files) - if (!CreateBlockFile(i, static_cast<FileType>(i + 1), true)) - return false; - - if (!OpenBlockFile(i)) - return false; - - // Walk this chain of files removing empty ones. - if (!RemoveEmptyFile(static_cast<FileType>(i + 1))) - return false; - } - - init_ = true; - return true; +void BlockBitmaps::Init(const BlockFilesBitmaps& bitmaps) { + bitmaps_ = bitmaps; } -bool BlockFiles::CreateBlock(FileType block_type, int block_count, - Addr* block_address) { - DCHECK(thread_checker_->CalledOnValidThread()); - if (block_type < RANKINGS || block_type > BLOCK_4K || - block_count < 1 || block_count > 4) - return false; - if (!init_) +bool BlockBitmaps::CreateBlock(FileType block_type, + int block_count, + Addr* block_address) { + DCHECK_NE(block_type, EXTERNAL); + DCHECK_NE(block_type, RANKINGS); + if (block_count < 1 || block_count > kMaxNumBlocks) return false; - MappedFile* file = FileForNewBlock(block_type, block_count); - if (!file) + int header_num = HeaderNumberForNewBlock(block_type, block_count); + if (header_num < 0) return false; - ScopedFlush flush(file); - BlockFileHeader* header = reinterpret_cast<BlockFileHeader*>(file->buffer()); - - int target_size = 0; - for (int i = block_count; i <= 4; i++) { - if (header->empty[i - 1]) { - target_size = i; - break; - } - } - - DCHECK(target_size); int index; - if (!CreateMapBlock(target_size, block_count, header, &index)) + if (!bitmaps_[header_num].CreateMapBlock(block_count, &index)) + return false; + + if (!index && (block_type == BLOCK_ENTRIES || block_type == BLOCK_EVICTED) && + !bitmaps_[header_num].CreateMapBlock(block_count, &index)) { + // index 0 for entries is a reserved value. return false; + } - Addr address(block_type, block_count, header->this_file, index); + Addr address(block_type, block_count, bitmaps_[header_num].FileId(), index); block_address->set_value(address.value()); Trace("CreateBlock 0x%x", address.value()); return true; } -void BlockFiles::DeleteBlock(Addr address, bool deep) { - DCHECK(thread_checker_->CalledOnValidThread()); +void BlockBitmaps::DeleteBlock(Addr address) { if (!address.is_initialized() || address.is_separate_file()) return; - if (!zero_buffer_) { - zero_buffer_ = new char[Addr::BlockSizeForFileType(BLOCK_4K) * 4]; - memset(zero_buffer_, 0, Addr::BlockSizeForFileType(BLOCK_4K) * 4); - } - MappedFile* file = GetFile(address); - if (!file) + int header_num = GetHeaderNumber(address); + if (header_num < 0) return; Trace("DeleteBlock 0x%x", address.value()); - - size_t size = address.BlockSize() * address.num_blocks(); - size_t offset = address.start_block() * address.BlockSize() + - kBlockHeaderSize; - if (deep) - file->Write(zero_buffer_, size, offset); - - BlockFileHeader* header = reinterpret_cast<BlockFileHeader*>(file->buffer()); - DeleteMapBlock(address.start_block(), address.num_blocks(), header); - file->Flush(); - - if (!header->num_entries) { - // This file is now empty. Let's try to delete it. - FileType type = Addr::RequiredFileType(header->entry_size); - if (Addr::BlockSizeForFileType(RANKINGS) == header->entry_size) - type = RANKINGS; - RemoveEmptyFile(type); // Ignore failures. - } + bitmaps_[header_num].DeleteMapBlock(address.start_block(), + address.num_blocks()); } -void BlockFiles::CloseFiles() { - if (init_) { - DCHECK(thread_checker_->CalledOnValidThread()); - } - init_ = false; - for (unsigned int i = 0; i < block_files_.size(); i++) { - if (block_files_[i]) { - block_files_[i]->Release(); - block_files_[i] = NULL; - } - } - block_files_.clear(); +void BlockBitmaps::Clear() { + bitmaps_.clear(); } -void BlockFiles::ReportStats() { - DCHECK(thread_checker_->CalledOnValidThread()); +void BlockBitmaps::ReportStats() { int used_blocks[kFirstAdditionalBlockFile]; int load[kFirstAdditionalBlockFile]; for (int i = 0; i < kFirstAdditionalBlockFile; i++) { @@ -155,176 +85,92 @@ void BlockFiles::ReportStats() { UMA_HISTOGRAM_ENUMERATION("DiskCache.BlockLoad_3", load[3], 101); } -bool BlockFiles::IsValid(Addr address) { +bool BlockBitmaps::IsValid(Addr address) { #ifdef NDEBUG return true; #else if (!address.is_initialized() || address.is_separate_file()) return false; - MappedFile* file = GetFile(address); - if (!file) + int header_num = GetHeaderNumber(address); + if (header_num < 0) return false; - BlockFileHeader* header = reinterpret_cast<BlockFileHeader*>(file->buffer()); - bool rv = UsedMapBlock(address.start_block(), address.num_blocks(), header); + bool rv = bitmaps_[header_num].UsedMapBlock(address.start_block(), + address.num_blocks()); DCHECK(rv); - - static bool read_contents = false; - if (read_contents) { - scoped_ptr<char[]> buffer; - buffer.reset(new char[Addr::BlockSizeForFileType(BLOCK_4K) * 4]); - size_t size = address.BlockSize() * address.num_blocks(); - size_t offset = address.start_block() * address.BlockSize() + - kBlockHeaderSize; - bool ok = file->Read(buffer.get(), size, offset); - DCHECK(ok); - } - return rv; #endif } -MappedFile* BlockFiles::GetFile(Addr address) { - DCHECK(thread_checker_->CalledOnValidThread()); - DCHECK(block_files_.size() >= 4); +int BlockBitmaps::GetHeaderNumber(Addr address) { + DCHECK_GE(bitmaps_.size(), static_cast<size_t>(kFirstAdditionalBlockFileV3)); DCHECK(address.is_block_file() || !address.is_initialized()); if (!address.is_initialized()) - return NULL; + return -1; int file_index = address.FileNumber(); - if (static_cast<unsigned int>(file_index) >= block_files_.size() || - !block_files_[file_index]) { - // We need to open the file - if (!OpenBlockFile(file_index)) - return NULL; - } - DCHECK(block_files_.size() >= static_cast<unsigned int>(file_index)); - return block_files_[file_index]; -} - -bool BlockFiles::GrowBlockFile(MappedFile* file, BlockFileHeader* header) { - if (kMaxBlocks == header->max_entries) - return false; - - ScopedFlush flush(file); - DCHECK(!header->empty[3]); - int new_size = header->max_entries + 1024; - if (new_size > kMaxBlocks) - new_size = kMaxBlocks; - - int new_size_bytes = new_size * header->entry_size + sizeof(*header); - - if (!file->SetLength(new_size_bytes)) { - // Most likely we are trying to truncate the file, so the header is wrong. - if (header->updating < 10 && !FixBlockFileHeader(file)) { - // If we can't fix the file increase the lock guard so we'll pick it on - // the next start and replace it. - header->updating = 100; - return false; - } - return (header->max_entries >= new_size); - } + if (static_cast<unsigned int>(file_index) >= bitmaps_.size()) + return -1; - FileLock lock(header); - header->empty[3] = (new_size - header->max_entries) / 4; // 4 blocks entries - header->max_entries = new_size; - - return true; + return file_index; } -MappedFile* BlockFiles::FileForNewBlock(FileType block_type, int block_count) { - COMPILE_ASSERT(RANKINGS == 1, invalid_file_type); - MappedFile* file = block_files_[block_type - 1]; - BlockFileHeader* header = reinterpret_cast<BlockFileHeader*>(file->buffer()); +int BlockBitmaps::HeaderNumberForNewBlock(FileType block_type, + int block_count) { + DCHECK_GT(block_type, 0); + int header_num = block_type - 1; + bool found = true; TimeTicks start = TimeTicks::Now(); - while (NeedToGrowBlockFile(header, block_count)) { - if (kMaxBlocks == header->max_entries) { - file = NextFile(file); - if (!file) - return NULL; - header = reinterpret_cast<BlockFileHeader*>(file->buffer()); - continue; + while (bitmaps_[header_num].NeedToGrowBlockFile(block_count)) { + header_num = bitmaps_[header_num].NextFileId(); + if (!header_num) { + found = false; + break; } - - if (!GrowBlockFile(file, header)) - return NULL; - break; } - HISTOGRAM_TIMES("DiskCache.GetFileForNewBlock", TimeTicks::Now() - start); - return file; -} -// Note that we expect to be called outside of a FileLock... however, we cannot -// DCHECK on header->updating because we may be fixing a crash. -bool BlockFiles::FixBlockFileHeader(MappedFile* file) { - ScopedFlush flush(file); - BlockFileHeader* header = reinterpret_cast<BlockFileHeader*>(file->buffer()); - int file_size = static_cast<int>(file->GetLength()); - if (file_size < static_cast<int>(sizeof(*header))) - return false; // file_size > 2GB is also an error. - - const int kMinBlockSize = 36; - const int kMaxBlockSize = 4096; - if (header->entry_size < kMinBlockSize || - header->entry_size > kMaxBlockSize || header->num_entries < 0) - return false; - - // Make sure that we survive crashes. - header->updating = 1; - int expected = header->entry_size * header->max_entries + sizeof(*header); - if (file_size != expected) { - int max_expected = header->entry_size * kMaxBlocks + sizeof(*header); - if (file_size < expected || header->empty[3] || file_size > max_expected) { - NOTREACHED(); - LOG(ERROR) << "Unexpected file size"; - return false; - } - // We were in the middle of growing the file. - int num_entries = (file_size - sizeof(*header)) / header->entry_size; - header->max_entries = num_entries; + if (!found) { + // Restart the search, looking for any file with space. We know that all + // files of this type are low on free blocks, but we cannot grow any file + // at this time. + header_num = block_type - 1; + do { + if (bitmaps_[header_num].CanAllocate(block_count)) { + found = true; // Make sure file 0 is not mistaken with a failure. + break; + } + header_num = bitmaps_[header_num].NextFileId(); + } while (header_num); + + if (!found) + header_num = -1; } - FixAllocationCounters(header); - int empty_blocks = EmptyBlocks(header); - if (empty_blocks + header->num_entries > header->max_entries) - header->num_entries = header->max_entries - empty_blocks; - - if (!ValidateCounters(header)) - return false; - - header->updating = 0; - return true; + HISTOGRAM_TIMES("DiskCache.GetFileForNewBlock", TimeTicks::Now() - start); + return header_num; } // We are interested in the total number of blocks used by this file type, and // the max number of blocks that we can store (reported as the percentage of // used blocks). In order to find out the number of used blocks, we have to // substract the empty blocks from the total blocks for each file in the chain. -void BlockFiles::GetFileStats(int index, int* used_count, int* load) { +void BlockBitmaps::GetFileStats(int index, int* used_count, int* load) { int max_blocks = 0; *used_count = 0; *load = 0; - for (;;) { - if (!block_files_[index] && !OpenBlockFile(index)) - return; - - BlockFileHeader* header = - reinterpret_cast<BlockFileHeader*>(block_files_[index]->buffer()); - - max_blocks += header->max_entries; - int used = header->max_entries; - for (int i = 0; i < 4; i++) { - used -= header->empty[i] * (i + 1); - DCHECK_GE(used, 0); - } + do { + int capacity = bitmaps_[index].Capacity(); + int used = capacity - bitmaps_[index].EmptyBlocks(); + DCHECK_GE(used, 0); + + max_blocks += capacity; *used_count += used; - if (!header->next_file) - break; - index = header->next_file; - } + index = bitmaps_[index].NextFileId(); + } while (index); + if (max_blocks) *load = *used_count * 100 / max_blocks; } diff --git a/chromium/net/disk_cache/v3/block_bitmaps.h b/chromium/net/disk_cache/v3/block_bitmaps.h index eaf87609912..111d57b0b13 100644 --- a/chromium/net/disk_cache/v3/block_bitmaps.h +++ b/chromium/net/disk_cache/v3/block_bitmaps.h @@ -4,43 +4,39 @@ // See net/disk_cache/disk_cache.h for the public interface. -#ifndef NET_DISK_CACHE_BLOCK_FILES_H_ -#define NET_DISK_CACHE_BLOCK_FILES_H_ - -#include <vector> +#ifndef NET_DISK_CACHE_V3_BLOCK_BITMAPS_H_ +#define NET_DISK_CACHE_V3_BLOCK_BITMAPS_H_ #include "base/files/file_path.h" -#include "base/gtest_prod_util.h" -#include "base/memory/scoped_ptr.h" #include "net/base/net_export.h" #include "net/disk_cache/addr.h" -#include "net/disk_cache/mapped_file.h" +#include "net/disk_cache/block_files.h" namespace disk_cache { -// This class handles the set of block-files open by the disk cache. -class NET_EXPORT_PRIVATE BlockFiles { +class BackendImplV3; + +// This class is the interface in the v3 disk cache to the set of files holding +// cached data that is small enough to not be efficiently stored in a dedicated +// file (i.e. < kMaxBlockSize). It is primarily used to allocate and free +// regions in those files used to store data. +class NET_EXPORT_PRIVATE BlockBitmaps { public: - explicit BlockFiles(const base::FilePath& path); - ~BlockFiles(); + BlockBitmaps(); + ~BlockBitmaps(); - // Performs the object initialization. create_files indicates if the backing - // files should be created or just open. - bool Init(bool create_files); + void Init(const BlockFilesBitmaps& bitmaps); // Creates a new entry on a block file. block_type indicates the size of block // to be used (as defined on cache_addr.h), block_count is the number of // blocks to allocate, and block_address is the address of the new entry. bool CreateBlock(FileType block_type, int block_count, Addr* block_address); - // Removes an entry from the block files. If deep is true, the storage is zero - // filled; otherwise the entry is removed but the data is not altered (must be - // already zeroed). - void DeleteBlock(Addr address, bool deep); + // Removes an entry from the block files. + void DeleteBlock(Addr address); - // Close all the files and set the internal state to be initializad again. The - // cache is being purged. - void CloseFiles(); + // Releases the internal bitmaps. The cache is being purged. + void Clear(); // Sends UMA stats. void ReportStats(); @@ -50,26 +46,20 @@ class NET_EXPORT_PRIVATE BlockFiles { bool IsValid(Addr address); private: - // Returns the file that stores a given address. - MappedFile* GetFile(Addr address); - - // Attemp to grow this file. Fails if the file cannot be extended anymore. - bool GrowBlockFile(MappedFile* file, BlockFileHeader* header); - - // Returns the appropriate file to use for a new block. - MappedFile* FileForNewBlock(FileType block_type, int block_count); + // Returns the header number that stores a given address. + int GetHeaderNumber(Addr address); - // Restores the header of a potentially inconsistent file. - bool FixBlockFileHeader(MappedFile* file); + // Returns the appropriate header to use for a new block. + int HeaderNumberForNewBlock(FileType block_type, int block_count); // Retrieves stats for the given file index. void GetFileStats(int index, int* used_count, int* load); - bool init_; + BlockFilesBitmaps bitmaps_; - DISALLOW_COPY_AND_ASSIGN(BlockFiles); + DISALLOW_COPY_AND_ASSIGN(BlockBitmaps); }; } // namespace disk_cache -#endif // NET_DISK_CACHE_BLOCK_FILES_H_ +#endif // NET_DISK_CACHE_V3_BLOCK_BITMAPS_H_ diff --git a/chromium/net/disk_cache/v3/block_bitmaps_unittest.cc b/chromium/net/disk_cache/v3/block_bitmaps_unittest.cc index fa7c5dbb742..981bdecfcea 100644 --- a/chromium/net/disk_cache/v3/block_bitmaps_unittest.cc +++ b/chromium/net/disk_cache/v3/block_bitmaps_unittest.cc @@ -2,342 +2,64 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#include "base/file_util.h" -#include "base/files/file_enumerator.h" +#include "net/disk_cache/addr.h" #include "net/disk_cache/block_files.h" -#include "net/disk_cache/disk_cache.h" -#include "net/disk_cache/disk_cache_test_base.h" -#include "net/disk_cache/disk_cache_test_util.h" +#include "net/disk_cache/disk_format_base.h" +#include "net/disk_cache/v3/block_bitmaps.h" #include "testing/gtest/include/gtest/gtest.h" -using base::Time; - -namespace { - -// Returns the number of files in this folder. -int NumberOfFiles(const base::FilePath& path) { - base::FileEnumerator iter(path, false, base::FileEnumerator::FILES); - int count = 0; - for (base::FilePath file = iter.Next(); !file.value().empty(); - file = iter.Next()) { - count++; - } - return count; -} - -} // namespace; - -namespace disk_cache { - -TEST_F(DiskCacheTest, BlockFiles_Grow) { - ASSERT_TRUE(CleanupCacheDir()); - ASSERT_TRUE(file_util::CreateDirectory(cache_path_)); - - BlockFiles files(cache_path_); - ASSERT_TRUE(files.Init(true)); - - const int kMaxSize = 35000; - Addr address[kMaxSize]; - - // Fill up the 32-byte block file (use three files). - for (int i = 0; i < kMaxSize; i++) { - EXPECT_TRUE(files.CreateBlock(RANKINGS, 4, &address[i])); - } - EXPECT_EQ(6, NumberOfFiles(cache_path_)); - - // Make sure we don't keep adding files. - for (int i = 0; i < kMaxSize * 4; i += 2) { - int target = i % kMaxSize; - files.DeleteBlock(address[target], false); - EXPECT_TRUE(files.CreateBlock(RANKINGS, 4, &address[target])); - } - EXPECT_EQ(6, NumberOfFiles(cache_path_)); -} - -// We should be able to delete empty block files. -TEST_F(DiskCacheTest, BlockFiles_Shrink) { - ASSERT_TRUE(CleanupCacheDir()); - ASSERT_TRUE(file_util::CreateDirectory(cache_path_)); - - BlockFiles files(cache_path_); - ASSERT_TRUE(files.Init(true)); - - const int kMaxSize = 35000; - Addr address[kMaxSize]; - - // Fill up the 32-byte block file (use three files). - for (int i = 0; i < kMaxSize; i++) { - EXPECT_TRUE(files.CreateBlock(RANKINGS, 4, &address[i])); - } - - // Now delete all the blocks, so that we can delete the two extra files. - for (int i = 0; i < kMaxSize; i++) { - files.DeleteBlock(address[i], false); - } - EXPECT_EQ(4, NumberOfFiles(cache_path_)); -} - -// Handling of block files not properly closed. -TEST_F(DiskCacheTest, BlockFiles_Recover) { - ASSERT_TRUE(CleanupCacheDir()); - ASSERT_TRUE(file_util::CreateDirectory(cache_path_)); - - BlockFiles files(cache_path_); - ASSERT_TRUE(files.Init(true)); - - const int kNumEntries = 2000; - CacheAddr entries[kNumEntries]; - - int seed = static_cast<int>(Time::Now().ToInternalValue()); - srand(seed); - for (int i = 0; i < kNumEntries; i++) { - Addr address(0); - int size = (rand() % 4) + 1; - EXPECT_TRUE(files.CreateBlock(RANKINGS, size, &address)); - entries[i] = address.value(); - } - - for (int i = 0; i < kNumEntries; i++) { - int source1 = rand() % kNumEntries; - int source2 = rand() % kNumEntries; - CacheAddr temp = entries[source1]; - entries[source1] = entries[source2]; - entries[source2] = temp; - } - - for (int i = 0; i < kNumEntries / 2; i++) { - Addr address(entries[i]); - files.DeleteBlock(address, false); - } - - // At this point, there are kNumEntries / 2 entries on the file, randomly - // distributed both on location and size. - - Addr address(entries[kNumEntries / 2]); - MappedFile* file = files.GetFile(address); - ASSERT_TRUE(NULL != file); - - BlockFileHeader* header = - reinterpret_cast<BlockFileHeader*>(file->buffer()); - ASSERT_TRUE(NULL != header); - - ASSERT_EQ(0, header->updating); - - int max_entries = header->max_entries; - int empty_1 = header->empty[0]; - int empty_2 = header->empty[1]; - int empty_3 = header->empty[2]; - int empty_4 = header->empty[3]; - - // Corrupt the file. - header->max_entries = header->empty[0] = 0; - header->empty[1] = header->empty[2] = header->empty[3] = 0; - header->updating = -1; - - files.CloseFiles(); - - ASSERT_TRUE(files.Init(false)); - - // The file must have been fixed. - file = files.GetFile(address); - ASSERT_TRUE(NULL != file); - - header = reinterpret_cast<BlockFileHeader*>(file->buffer()); - ASSERT_TRUE(NULL != header); - - ASSERT_EQ(0, header->updating); - - EXPECT_EQ(max_entries, header->max_entries); - EXPECT_EQ(empty_1, header->empty[0]); - EXPECT_EQ(empty_2, header->empty[1]); - EXPECT_EQ(empty_3, header->empty[2]); - EXPECT_EQ(empty_4, header->empty[3]); -} - -// Handling of truncated files. -TEST_F(DiskCacheTest, BlockFiles_ZeroSizeFile) { - ASSERT_TRUE(CleanupCacheDir()); - ASSERT_TRUE(file_util::CreateDirectory(cache_path_)); - - BlockFiles files(cache_path_); - ASSERT_TRUE(files.Init(true)); - - base::FilePath filename = files.Name(0); - files.CloseFiles(); - // Truncate one of the files. - { - scoped_refptr<File> file(new File); - ASSERT_TRUE(file->Init(filename)); - EXPECT_TRUE(file->SetLength(0)); - } - - // Initializing should fail, not crash. - ASSERT_FALSE(files.Init(false)); -} - -// Handling of truncated files (non empty). -TEST_F(DiskCacheTest, BlockFiles_TruncatedFile) { - ASSERT_TRUE(CleanupCacheDir()); - ASSERT_TRUE(file_util::CreateDirectory(cache_path_)); - - BlockFiles files(cache_path_); - ASSERT_TRUE(files.Init(true)); - Addr address; - EXPECT_TRUE(files.CreateBlock(RANKINGS, 2, &address)); - - base::FilePath filename = files.Name(0); - files.CloseFiles(); - // Truncate one of the files. - { - scoped_refptr<File> file(new File); - ASSERT_TRUE(file->Init(filename)); - EXPECT_TRUE(file->SetLength(15000)); - } - - // Initializing should fail, not crash. - ASSERT_FALSE(files.Init(false)); -} - -// Tests detection of out of sync counters. -TEST_F(DiskCacheTest, BlockFiles_Counters) { - ASSERT_TRUE(CleanupCacheDir()); - ASSERT_TRUE(file_util::CreateDirectory(cache_path_)); - - BlockFiles files(cache_path_); - ASSERT_TRUE(files.Init(true)); - - // Create a block of size 2. - Addr address(0); - EXPECT_TRUE(files.CreateBlock(RANKINGS, 2, &address)); - - MappedFile* file = files.GetFile(address); - ASSERT_TRUE(NULL != file); - - BlockFileHeader* header = reinterpret_cast<BlockFileHeader*>(file->buffer()); - ASSERT_TRUE(NULL != header); - ASSERT_EQ(0, header->updating); - - // Alter the counters so that the free space doesn't add up. - header->empty[2] = 50; // 50 free blocks of size 3. - files.CloseFiles(); - - ASSERT_TRUE(files.Init(false)); - file = files.GetFile(address); - ASSERT_TRUE(NULL != file); - header = reinterpret_cast<BlockFileHeader*>(file->buffer()); - ASSERT_TRUE(NULL != header); - - // The file must have been fixed. - ASSERT_EQ(0, header->empty[2]); - - // Change the number of entries. - header->num_entries = 3; - header->updating = 1; - files.CloseFiles(); - - ASSERT_TRUE(files.Init(false)); - file = files.GetFile(address); - ASSERT_TRUE(NULL != file); - header = reinterpret_cast<BlockFileHeader*>(file->buffer()); - ASSERT_TRUE(NULL != header); - - // The file must have been "fixed". - ASSERT_EQ(2, header->num_entries); - - // Change the number of entries. - header->num_entries = -1; - header->updating = 1; - files.CloseFiles(); - - // Detect the error. - ASSERT_FALSE(files.Init(false)); -} - -// An invalid file can be detected after init. -TEST_F(DiskCacheTest, BlockFiles_InvalidFile) { - ASSERT_TRUE(CleanupCacheDir()); - ASSERT_TRUE(file_util::CreateDirectory(cache_path_)); - - BlockFiles files(cache_path_); - ASSERT_TRUE(files.Init(true)); - - // Let's access block 10 of file 5. (There is no file). - Addr addr(BLOCK_256, 1, 5, 10); - EXPECT_TRUE(NULL == files.GetFile(addr)); - - // Let's create an invalid file. - base::FilePath filename(files.Name(5)); - char header[kBlockHeaderSize]; - memset(header, 'a', kBlockHeaderSize); - EXPECT_EQ(kBlockHeaderSize, - file_util::WriteFile(filename, header, kBlockHeaderSize)); - - EXPECT_TRUE(NULL == files.GetFile(addr)); - - // The file should not have been changed (it is still invalid). - EXPECT_TRUE(NULL == files.GetFile(addr)); -} - -// Tests that we generate the correct file stats. -TEST_F(DiskCacheTest, BlockFiles_Stats) { - ASSERT_TRUE(CopyTestCache("remove_load1")); - - BlockFiles files(cache_path_); - ASSERT_TRUE(files.Init(false)); - int used, load; - - files.GetFileStats(0, &used, &load); - EXPECT_EQ(101, used); - EXPECT_EQ(9, load); - - files.GetFileStats(1, &used, &load); - EXPECT_EQ(203, used); - EXPECT_EQ(19, load); - - files.GetFileStats(2, &used, &load); - EXPECT_EQ(0, used); - EXPECT_EQ(0, load); -} - // Tests that we add and remove blocks correctly. -TEST_F(DiskCacheTest, AllocationMap) { - ASSERT_TRUE(CleanupCacheDir()); - ASSERT_TRUE(file_util::CreateDirectory(cache_path_)); +TEST(DiskCacheBlockBitmaps, V3AllocationMap) { + disk_cache::BlockBitmaps block_bitmaps; + disk_cache::BlockFilesBitmaps bitmaps; + + const int kNumHeaders = 10; + disk_cache::BlockFileHeader headers[kNumHeaders]; + for (int i = 0; i < kNumHeaders; i++) { + memset(&headers[i], 0, sizeof(headers[i])); + headers[i].magic = disk_cache::kBlockMagic; + headers[i].version = disk_cache::kBlockCurrentVersion; + headers[i].this_file = static_cast<int16>(i); + headers[i].empty[3] = 200; + headers[i].max_entries = 800; + bitmaps.push_back(disk_cache::BlockHeader(&headers[i])); + } - BlockFiles files(cache_path_); - ASSERT_TRUE(files.Init(true)); + block_bitmaps.Init(bitmaps); // Create a bunch of entries. const int kSize = 100; - Addr address[kSize]; + disk_cache::Addr address[kSize]; for (int i = 0; i < kSize; i++) { SCOPED_TRACE(i); int block_size = i % 4 + 1; - EXPECT_TRUE(files.CreateBlock(BLOCK_1K, block_size, &address[i])); - EXPECT_EQ(BLOCK_1K, address[i].file_type()); + ASSERT_TRUE(block_bitmaps.CreateBlock(disk_cache::BLOCK_1K, block_size, + &address[i])); + EXPECT_EQ(disk_cache::BLOCK_1K, address[i].file_type()); EXPECT_EQ(block_size, address[i].num_blocks()); int start = address[i].start_block(); + + // Verify that the allocated entry doesn't cross a 4 block boundary. EXPECT_EQ(start / 4, (start + block_size - 1) / 4); } for (int i = 0; i < kSize; i++) { SCOPED_TRACE(i); - EXPECT_TRUE(files.IsValid(address[i])); + EXPECT_TRUE(block_bitmaps.IsValid(address[i])); } // The first part of the allocation map should be completely filled. We used - // 10 bits per each four entries, so 250 bits total. - BlockFileHeader* header = - reinterpret_cast<BlockFileHeader*>(files.GetFile(address[0])->buffer()); - uint8* buffer = reinterpret_cast<uint8*>(&header->allocation_map); - for (int i =0; i < 29; i++) { + // 10 bits per each of four entries, so 250 bits total. All entries should go + // to the third file. + uint8* buffer = reinterpret_cast<uint8*>(&headers[2].allocation_map); + for (int i = 0; i < 29; i++) { SCOPED_TRACE(i); EXPECT_EQ(0xff, buffer[i]); } for (int i = 0; i < kSize; i++) { SCOPED_TRACE(i); - files.DeleteBlock(address[i], false); + block_bitmaps.DeleteBlock(address[i]); } // The allocation map should be empty. @@ -346,5 +68,3 @@ TEST_F(DiskCacheTest, AllocationMap) { EXPECT_EQ(0, buffer[i]); } } - -} // namespace disk_cache diff --git a/chromium/net/dns/address_sorter_posix_unittest.cc b/chromium/net/dns/address_sorter_posix_unittest.cc index 96cbfc6fcb0..c4517379957 100644 --- a/chromium/net/dns/address_sorter_posix_unittest.cc +++ b/chromium/net/dns/address_sorter_posix_unittest.cc @@ -10,6 +10,8 @@ #include "net/base/net_util.h" #include "net/base/test_completion_callback.h" #include "net/socket/client_socket_factory.h" +#include "net/socket/ssl_client_socket.h" +#include "net/socket/stream_socket.h" #include "net/udp/datagram_client_socket.h" #include "testing/gtest/include/gtest/gtest.h" @@ -90,27 +92,27 @@ class TestSocketFactory : public ClientSocketFactory { TestSocketFactory() {} virtual ~TestSocketFactory() {} - virtual DatagramClientSocket* CreateDatagramClientSocket( + virtual scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket( DatagramSocket::BindType, const RandIntCallback&, NetLog*, const NetLog::Source&) OVERRIDE { - return new TestUDPClientSocket(&mapping_); + return scoped_ptr<DatagramClientSocket>(new TestUDPClientSocket(&mapping_)); } - virtual StreamSocket* CreateTransportClientSocket( + virtual scoped_ptr<StreamSocket> CreateTransportClientSocket( const AddressList&, NetLog*, const NetLog::Source&) OVERRIDE { NOTIMPLEMENTED(); - return NULL; + return scoped_ptr<StreamSocket>(); } - virtual SSLClientSocket* CreateSSLClientSocket( - ClientSocketHandle*, + virtual scoped_ptr<SSLClientSocket> CreateSSLClientSocket( + scoped_ptr<ClientSocketHandle>, const HostPortPair&, const SSLConfig&, const SSLClientSocketContext&) OVERRIDE { NOTIMPLEMENTED(); - return NULL; + return scoped_ptr<SSLClientSocket>(); } virtual void ClearSSLSessionCache() OVERRIDE { NOTIMPLEMENTED(); diff --git a/chromium/net/dns/dns_client.cc b/chromium/net/dns/dns_client.cc index 976f1533905..9e29ca4265d 100644 --- a/chromium/net/dns/dns_client.cc +++ b/chromium/net/dns/dns_client.cc @@ -27,7 +27,7 @@ class DnsClientImpl : public DnsClient { virtual void SetConfig(const DnsConfig& config) OVERRIDE { factory_.reset(); session_ = NULL; - if (config.IsValid()) { + if (config.IsValid() && !config.unhandled_options) { ClientSocketFactory* factory = ClientSocketFactory::GetDefaultFactory(); scoped_ptr<DnsSocketPool> socket_pool( config.randomize_ports ? DnsSocketPool::CreateDefault(factory) diff --git a/chromium/net/dns/dns_client.h b/chromium/net/dns/dns_client.h index 650c7d0d416..0484d44c0c6 100644 --- a/chromium/net/dns/dns_client.h +++ b/chromium/net/dns/dns_client.h @@ -22,7 +22,8 @@ class NET_EXPORT DnsClient { public: virtual ~DnsClient() {} - // Creates a new DnsTransactionFactory according to the new |config|. + // Destroys the current DnsTransactionFactory and creates a new one + // according to |config|, unless it is invalid or has |unhandled_options|. virtual void SetConfig(const DnsConfig& config) = 0; // Returns NULL if the current config is not valid. diff --git a/chromium/net/dns/dns_config_service.cc b/chromium/net/dns/dns_config_service.cc index ea8a3421cd2..66131825571 100644 --- a/chromium/net/dns/dns_config_service.cc +++ b/chromium/net/dns/dns_config_service.cc @@ -14,13 +14,15 @@ namespace net { // Default values are taken from glibc resolv.h except timeout which is set to // |kDnsTimeoutSeconds|. DnsConfig::DnsConfig() - : append_to_multi_label_name(true), + : unhandled_options(false), + append_to_multi_label_name(true), randomize_ports(false), ndots(1), timeout(base::TimeDelta::FromSeconds(kDnsTimeoutSeconds)), attempts(2), rotate(false), - edns0(false) {} + edns0(false), + use_local_ipv6(false) {} DnsConfig::~DnsConfig() {} @@ -31,23 +33,27 @@ bool DnsConfig::Equals(const DnsConfig& d) const { bool DnsConfig::EqualsIgnoreHosts(const DnsConfig& d) const { return (nameservers == d.nameservers) && (search == d.search) && + (unhandled_options == d.unhandled_options) && (append_to_multi_label_name == d.append_to_multi_label_name) && (ndots == d.ndots) && (timeout == d.timeout) && (attempts == d.attempts) && (rotate == d.rotate) && - (edns0 == d.edns0); + (edns0 == d.edns0) && + (use_local_ipv6 == d.use_local_ipv6); } void DnsConfig::CopyIgnoreHosts(const DnsConfig& d) { nameservers = d.nameservers; search = d.search; + unhandled_options = d.unhandled_options; append_to_multi_label_name = d.append_to_multi_label_name; ndots = d.ndots; timeout = d.timeout; attempts = d.attempts; rotate = d.rotate; edns0 = d.edns0; + use_local_ipv6 = d.use_local_ipv6; } base::Value* DnsConfig::ToValue() const { @@ -63,12 +69,14 @@ base::Value* DnsConfig::ToValue() const { list->Append(new base::StringValue(search[i])); dict->Set("search", list); + dict->SetBoolean("unhandled_options", unhandled_options); dict->SetBoolean("append_to_multi_label_name", append_to_multi_label_name); dict->SetInteger("ndots", ndots); dict->SetDouble("timeout", timeout.InSecondsF()); dict->SetInteger("attempts", attempts); dict->SetBoolean("rotate", rotate); dict->SetBoolean("edns0", edns0); + dict->SetBoolean("use_local_ipv6", use_local_ipv6); dict->SetInteger("num_hosts", hosts.size()); return dict; diff --git a/chromium/net/dns/dns_config_service.h b/chromium/net/dns/dns_config_service.h index 4babb9e7d37..7386e601790 100644 --- a/chromium/net/dns/dns_config_service.h +++ b/chromium/net/dns/dns_config_service.h @@ -59,6 +59,10 @@ struct NET_EXPORT_PRIVATE DnsConfig { DnsHosts hosts; + // True if there are options set in the system configuration that are not yet + // supported by DnsClient. + bool unhandled_options; + // AppendToMultiLabelName: is suffix search performed for multi-label names? // True, except on Windows where it can be configured. bool append_to_multi_label_name; @@ -79,6 +83,11 @@ struct NET_EXPORT_PRIVATE DnsConfig { bool rotate; // Enable EDNS0 extensions. bool edns0; + + // Indicates system configuration uses local IPv6 connectivity, e.g., + // DirectAccess. This is exposed for HostResolver to skip IPv6 probes, + // as it may cause them to return incorrect results. + bool use_local_ipv6; }; diff --git a/chromium/net/dns/dns_config_service_posix.cc b/chromium/net/dns/dns_config_service_posix.cc index ff2295e704a..baf917284c7 100644 --- a/chromium/net/dns/dns_config_service_posix.cc +++ b/chromium/net/dns/dns_config_service_posix.cc @@ -10,6 +10,7 @@ #include "base/bind.h" #include "base/files/file_path.h" #include "base/files/file_path_watcher.h" +#include "base/lazy_instance.h" #include "base/memory/scoped_ptr.h" #include "base/metrics/histogram.h" #include "base/time/time.h" @@ -20,6 +21,69 @@ #include "net/dns/notify_watcher_mac.h" #include "net/dns/serial_worker.h" +#if defined(OS_MACOSX) +#include <dlfcn.h> + +#include "third_party/apple_apsl/dnsinfo.h" + +namespace { + +// dnsinfo symbols are available via libSystem.dylib, but can also be present in +// SystemConfiguration.framework. To avoid confusion, load them explicitly from +// libSystem.dylib. +class DnsInfoApi { + public: + typedef const char* (*dns_configuration_notify_key_t)(); + typedef dns_config_t* (*dns_configuration_copy_t)(); + typedef void (*dns_configuration_free_t)(dns_config_t*); + + DnsInfoApi() + : dns_configuration_notify_key(NULL), + dns_configuration_copy(NULL), + dns_configuration_free(NULL) { + handle_ = dlopen("/usr/lib/libSystem.dylib", + RTLD_LAZY | RTLD_NOLOAD); + if (!handle_) + return; + dns_configuration_notify_key = + reinterpret_cast<dns_configuration_notify_key_t>( + dlsym(handle_, "dns_configuration_notify_key")); + dns_configuration_copy = + reinterpret_cast<dns_configuration_copy_t>( + dlsym(handle_, "dns_configuration_copy")); + dns_configuration_free = + reinterpret_cast<dns_configuration_free_t>( + dlsym(handle_, "dns_configuration_free")); + } + + ~DnsInfoApi() { + if (handle_) + dlclose(handle_); + } + + dns_configuration_notify_key_t dns_configuration_notify_key; + dns_configuration_copy_t dns_configuration_copy; + dns_configuration_free_t dns_configuration_free; + + private: + void* handle_; +}; + +const DnsInfoApi& GetDnsInfoApi() { + static base::LazyInstance<DnsInfoApi>::Leaky api = LAZY_INSTANCE_INITIALIZER; + return api.Get(); +} + +struct DnsConfigTDeleter { + inline void operator()(dns_config_t* ptr) const { + if (GetDnsInfoApi().dns_configuration_free) + GetDnsInfoApi().dns_configuration_free(ptr); + } +}; + +} // namespace +#endif // defined(OS_MACOSX) + namespace net { #if !defined(OS_ANDROID) @@ -31,14 +95,13 @@ const base::FilePath::CharType* kFilePathHosts = FILE_PATH_LITERAL("/etc/hosts"); #if defined(OS_MACOSX) -// From 10.7.3 configd-395.10/dnsinfo/dnsinfo.h -static const char* kDnsNotifyKey = - "com.apple.system.SystemConfiguration.dns_configuration"; - class ConfigWatcher { public: bool Watch(const base::Callback<void(bool succeeded)>& callback) { - return watcher_.Watch(kDnsNotifyKey, callback); + if (!GetDnsInfoApi().dns_configuration_notify_key) + return false; + return watcher_.Watch(GetDnsInfoApi().dns_configuration_notify_key(), + callback); } private: @@ -76,6 +139,7 @@ class ConfigWatcher { ConfigParsePosixResult ReadDnsConfig(DnsConfig* config) { ConfigParsePosixResult result; + config->unhandled_options = false; #if defined(OS_OPENBSD) // Note: res_ninit in glibc always returns 0 and sets RES_INIT. // res_init behaves the same way. @@ -100,6 +164,32 @@ ConfigParsePosixResult ReadDnsConfig(DnsConfig* config) { res_nclose(&res); #endif #endif + +#if defined(OS_MACOSX) + if (!GetDnsInfoApi().dns_configuration_copy) + return CONFIG_PARSE_POSIX_NO_DNSINFO; + scoped_ptr<dns_config_t, DnsConfigTDeleter> dns_config( + GetDnsInfoApi().dns_configuration_copy()); + if (!dns_config) + return CONFIG_PARSE_POSIX_NO_DNSINFO; + + // TODO(szym): Parse dns_config_t for resolvers rather than res_state. + // DnsClient can't handle domain-specific unscoped resolvers. + unsigned num_resolvers = 0; + for (int i = 0; i < dns_config->n_resolver; ++i) { + dns_resolver_t* resolver = dns_config->resolver[i]; + if (!resolver->n_nameserver) + continue; + if (resolver->options && !strcmp(resolver->options, "mdns")) + continue; + ++num_resolvers; + } + if (num_resolvers > 1) { + LOG(WARNING) << "dns_config has unhandled options!"; + config->unhandled_options = true; + return CONFIG_PARSE_POSIX_UNHANDLED_OPTIONS; + } +#endif // defined(OS_MACOSX) // Override timeout value to match default setting on Windows. config->timeout = base::TimeDelta::FromSeconds(kDnsTimeoutSeconds); return result; @@ -172,7 +262,18 @@ class DnsConfigServicePosix::ConfigReader : public SerialWorker { virtual void DoWork() OVERRIDE { base::TimeTicks start_time = base::TimeTicks::Now(); ConfigParsePosixResult result = ReadDnsConfig(&dns_config_); - success_ = (result == CONFIG_PARSE_POSIX_OK); + switch (result) { + case CONFIG_PARSE_POSIX_MISSING_OPTIONS: + case CONFIG_PARSE_POSIX_UNHANDLED_OPTIONS: + DCHECK(dns_config_.unhandled_options); + // Fall through. + case CONFIG_PARSE_POSIX_OK: + success_ = true; + break; + default: + success_ = false; + break; + } UMA_HISTOGRAM_ENUMERATION("AsyncDNS.ConfigParsePosix", result, CONFIG_PARSE_POSIX_MAX); UMA_HISTOGRAM_BOOLEAN("AsyncDNS.ConfigParseResult", success_); @@ -358,12 +459,16 @@ ConfigParsePosixResult ConvertResStateToDnsConfig(const struct __res_state& res, // The current implementation assumes these options are set. They normally // cannot be overwritten by /etc/resolv.conf unsigned kRequiredOptions = RES_RECURSE | RES_DEFNAMES | RES_DNSRCH; - if ((res.options & kRequiredOptions) != kRequiredOptions) + if ((res.options & kRequiredOptions) != kRequiredOptions) { + dns_config->unhandled_options = true; return CONFIG_PARSE_POSIX_MISSING_OPTIONS; + } unsigned kUnhandledOptions = RES_USEVC | RES_IGNTC | RES_USE_DNSSEC; - if (res.options & kUnhandledOptions) + if (res.options & kUnhandledOptions) { + dns_config->unhandled_options = true; return CONFIG_PARSE_POSIX_UNHANDLED_OPTIONS; + } if (dns_config->nameservers.empty()) return CONFIG_PARSE_POSIX_NO_NAMESERVERS; diff --git a/chromium/net/dns/dns_config_service_posix.h b/chromium/net/dns/dns_config_service_posix.h index 95a4377d932..be19ab90b08 100644 --- a/chromium/net/dns/dns_config_service_posix.h +++ b/chromium/net/dns/dns_config_service_posix.h @@ -53,6 +53,7 @@ enum ConfigParsePosixResult { CONFIG_PARSE_POSIX_NO_NAMESERVERS, CONFIG_PARSE_POSIX_MISSING_OPTIONS, CONFIG_PARSE_POSIX_UNHANDLED_OPTIONS, + CONFIG_PARSE_POSIX_NO_DNSINFO, CONFIG_PARSE_POSIX_MAX // Bounding values for enumeration. }; diff --git a/chromium/net/dns/dns_config_service_win.cc b/chromium/net/dns/dns_config_service_win.cc index 6d12155d8cc..fe97b74f73e 100644 --- a/chromium/net/dns/dns_config_service_win.cc +++ b/chromium/net/dns/dns_config_service_win.cc @@ -43,8 +43,19 @@ namespace { // Interval between retries to parse config. Used only until parsing succeeds. const int kRetryIntervalSeconds = 5; +// Registry key paths. +const wchar_t* const kTcpipPath = + L"SYSTEM\\CurrentControlSet\\Services\\Tcpip\\Parameters"; +const wchar_t* const kTcpip6Path = + L"SYSTEM\\CurrentControlSet\\Services\\Tcpip6\\Parameters"; +const wchar_t* const kDnscachePath = + L"SYSTEM\\CurrentControlSet\\Services\\Dnscache\\Parameters"; +const wchar_t* const kPolicyPath = + L"SOFTWARE\\Policies\\Microsoft\\Windows NT\\DNSClient"; const wchar_t* const kPrimaryDnsSuffixPath = L"SOFTWARE\\Policies\\Microsoft\\System\\DNSClient"; +const wchar_t* const kNRPTPath = + L"SOFTWARE\\Policies\\Microsoft\\Windows NT\\DNSClient\\DnsPolicyConfig"; enum HostsParseWinResult { HOSTS_PARSE_WIN_OK = 0, @@ -198,6 +209,10 @@ ConfigParseWinResult ReadSystemSettings(DnsSystemSettings* settings) { &settings->primary_dns_suffix)) { return CONFIG_PARSE_WIN_READ_PRIMARY_SUFFIX; } + + base::win::RegistryKeyIterator nrpt_rules(HKEY_LOCAL_MACHINE, kNRPTPath); + settings->have_name_resolution_policy = (nrpt_rules.SubkeyCount() > 0); + return CONFIG_PARSE_WIN_OK; } @@ -330,8 +345,7 @@ bool IsStatelessDiscoveryAddress(const IPAddressNumber& address) { address.begin()) && (address.back() < 4); } -} // namespace - +// Returns the path to the HOSTS file. base::FilePath GetHostsPath() { TCHAR buffer[MAX_PATH]; UINT rc = GetSystemDirectory(buffer, MAX_PATH); @@ -340,6 +354,92 @@ base::FilePath GetHostsPath() { FILE_PATH_LITERAL("drivers\\etc\\hosts")); } +void ConfigureSuffixSearch(const DnsSystemSettings& settings, + DnsConfig* config) { + // SearchList takes precedence, so check it first. + if (settings.policy_search_list.set) { + std::vector<std::string> search; + if (ParseSearchList(settings.policy_search_list.value, &search)) { + config->search.swap(search); + return; + } + // Even if invalid, the policy disables the user-specified setting below. + } else if (settings.tcpip_search_list.set) { + std::vector<std::string> search; + if (ParseSearchList(settings.tcpip_search_list.value, &search)) { + config->search.swap(search); + return; + } + } + + // In absence of explicit search list, suffix search is: + // [primary suffix, connection-specific suffix, devolution of primary suffix]. + // Primary suffix can be set by policy (primary_dns_suffix) or + // user setting (tcpip_domain). + // + // The policy (primary_dns_suffix) can be edited via Group Policy Editor + // (gpedit.msc) at Local Computer Policy => Computer Configuration + // => Administrative Template => Network => DNS Client => Primary DNS Suffix. + // + // The user setting (tcpip_domain) can be configurred at Computer Name in + // System Settings + std::string primary_suffix; + if ((settings.primary_dns_suffix.set && + ParseDomainASCII(settings.primary_dns_suffix.value, &primary_suffix)) || + (settings.tcpip_domain.set && + ParseDomainASCII(settings.tcpip_domain.value, &primary_suffix))) { + // Primary suffix goes in front. + config->search.insert(config->search.begin(), primary_suffix); + } else { + return; // No primary suffix, hence no devolution. + } + + // Devolution is determined by precedence: policy > dnscache > tcpip. + // |enabled|: UseDomainNameDevolution and |level|: DomainNameDevolutionLevel + // are overridden independently. + DnsSystemSettings::DevolutionSetting devolution = settings.policy_devolution; + + if (!devolution.enabled.set) + devolution.enabled = settings.dnscache_devolution.enabled; + if (!devolution.enabled.set) + devolution.enabled = settings.tcpip_devolution.enabled; + if (devolution.enabled.set && (devolution.enabled.value == 0)) + return; // Devolution disabled. + + // By default devolution is enabled. + + if (!devolution.level.set) + devolution.level = settings.dnscache_devolution.level; + if (!devolution.level.set) + devolution.level = settings.tcpip_devolution.level; + + // After the recent update, Windows will try to determine a safe default + // value by comparing the forest root domain (FRD) to the primary suffix. + // See http://support.microsoft.com/kb/957579 for details. + // For now, if the level is not set, we disable devolution, assuming that + // we will fallback to the system getaddrinfo anyway. This might cause + // performance loss for resolutions which depend on the system default + // devolution setting. + // + // If the level is explicitly set below 2, devolution is disabled. + if (!devolution.level.set || devolution.level.value < 2) + return; // Devolution disabled. + + // Devolve the primary suffix. This naive logic matches the observed + // behavior (see also ParseSearchList). If a suffix is not valid, it will be + // discarded when the fully-qualified name is converted to DNS format. + + unsigned num_dots = std::count(primary_suffix.begin(), + primary_suffix.end(), '.'); + + for (size_t offset = 0; num_dots >= devolution.level.value; --num_dots) { + offset = primary_suffix.find('.', offset + 1); + config->search.push_back(primary_suffix.substr(offset + 1)); + } +} + +} // namespace + bool ParseSearchList(const base::string16& value, std::vector<std::string>* output) { DCHECK(output); @@ -429,87 +529,16 @@ ConfigParseWinResult ConvertSettingsToDnsConfig( (settings.append_to_multi_label_name.value != 0); } - // SearchList takes precedence, so check it first. - if (settings.policy_search_list.set) { - std::vector<std::string> search; - if (ParseSearchList(settings.policy_search_list.value, &search)) { - config->search.swap(search); - return CONFIG_PARSE_WIN_OK; - } - // Even if invalid, the policy disables the user-specified setting below. - } else if (settings.tcpip_search_list.set) { - std::vector<std::string> search; - if (ParseSearchList(settings.tcpip_search_list.value, &search)) { - config->search.swap(search); - return CONFIG_PARSE_WIN_OK; - } + ConfigParseWinResult result = CONFIG_PARSE_WIN_OK; + if (settings.have_name_resolution_policy) { + config->unhandled_options = true; + // TODO(szym): only set this to true if NRPT has DirectAccess rules. + config->use_local_ipv6 = true; + result = CONFIG_PARSE_WIN_UNHANDLED_OPTIONS; } - // In absence of explicit search list, suffix search is: - // [primary suffix, connection-specific suffix, devolution of primary suffix]. - // Primary suffix can be set by policy (primary_dns_suffix) or - // user setting (tcpip_domain). - // - // The policy (primary_dns_suffix) can be edited via Group Policy Editor - // (gpedit.msc) at Local Computer Policy => Computer Configuration - // => Administrative Template => Network => DNS Client => Primary DNS Suffix. - // - // The user setting (tcpip_domain) can be configurred at Computer Name in - // System Settings - std::string primary_suffix; - if ((settings.primary_dns_suffix.set && - ParseDomainASCII(settings.primary_dns_suffix.value, &primary_suffix)) || - (settings.tcpip_domain.set && - ParseDomainASCII(settings.tcpip_domain.value, &primary_suffix))) { - // Primary suffix goes in front. - config->search.insert(config->search.begin(), primary_suffix); - } else { - return CONFIG_PARSE_WIN_OK; // No primary suffix, hence no devolution. - } - - // Devolution is determined by precedence: policy > dnscache > tcpip. - // |enabled|: UseDomainNameDevolution and |level|: DomainNameDevolutionLevel - // are overridden independently. - DnsSystemSettings::DevolutionSetting devolution = settings.policy_devolution; - - if (!devolution.enabled.set) - devolution.enabled = settings.dnscache_devolution.enabled; - if (!devolution.enabled.set) - devolution.enabled = settings.tcpip_devolution.enabled; - if (devolution.enabled.set && (devolution.enabled.value == 0)) - return CONFIG_PARSE_WIN_OK; // Devolution disabled. - - // By default devolution is enabled. - - if (!devolution.level.set) - devolution.level = settings.dnscache_devolution.level; - if (!devolution.level.set) - devolution.level = settings.tcpip_devolution.level; - - // After the recent update, Windows will try to determine a safe default - // value by comparing the forest root domain (FRD) to the primary suffix. - // See http://support.microsoft.com/kb/957579 for details. - // For now, if the level is not set, we disable devolution, assuming that - // we will fallback to the system getaddrinfo anyway. This might cause - // performance loss for resolutions which depend on the system default - // devolution setting. - // - // If the level is explicitly set below 2, devolution is disabled. - if (!devolution.level.set || devolution.level.value < 2) - return CONFIG_PARSE_WIN_OK; // Devolution disabled. - - // Devolve the primary suffix. This naive logic matches the observed - // behavior (see also ParseSearchList). If a suffix is not valid, it will be - // discarded when the fully-qualified name is converted to DNS format. - - unsigned num_dots = std::count(primary_suffix.begin(), - primary_suffix.end(), '.'); - - for (size_t offset = 0; num_dots >= devolution.level.value; --num_dots) { - offset = primary_suffix.find('.', offset + 1); - config->search.push_back(primary_suffix.substr(offset + 1)); - } - return CONFIG_PARSE_WIN_OK; + ConfigureSuffixSearch(settings, config); + return result; } // Watches registry and HOSTS file for changes. Must live on a thread which @@ -606,7 +635,8 @@ class DnsConfigServiceWin::ConfigReader : public SerialWorker { ConfigParseWinResult result = ReadSystemSettings(&settings); if (result == CONFIG_PARSE_WIN_OK) result = ConvertSettingsToDnsConfig(settings, &dns_config_); - success_ = (result == CONFIG_PARSE_WIN_OK); + success_ = (result == CONFIG_PARSE_WIN_OK || + result == CONFIG_PARSE_WIN_UNHANDLED_OPTIONS); UMA_HISTOGRAM_ENUMERATION("AsyncDNS.ConfigParseWin", result, CONFIG_PARSE_WIN_MAX); UMA_HISTOGRAM_BOOLEAN("AsyncDNS.ConfigParseResult", success_); diff --git a/chromium/net/dns/dns_config_service_win.h b/chromium/net/dns/dns_config_service_win.h index 06fc0d9663b..9503dc8593f 100644 --- a/chromium/net/dns/dns_config_service_win.h +++ b/chromium/net/dns/dns_config_service_win.h @@ -34,19 +34,6 @@ namespace net { namespace internal { -// Registry key paths. -const wchar_t* const kTcpipPath = - L"SYSTEM\\CurrentControlSet\\Services\\Tcpip\\Parameters"; -const wchar_t* const kTcpip6Path = - L"SYSTEM\\CurrentControlSet\\Services\\Tcpip6\\Parameters"; -const wchar_t* const kDnscachePath = - L"SYSTEM\\CurrentControlSet\\Services\\Dnscache\\Parameters"; -const wchar_t* const kPolicyPath = - L"SOFTWARE\\Policies\\Microsoft\\Windows NT\\DNSClient"; - -// Returns the path to the HOSTS file. -base::FilePath GetHostsPath(); - // Parses |value| as search list (comma-delimited list of domain names) from // a registry key and stores it in |out|. Returns true on success. Empty // entries (e.g., "chromium.org,,org") terminate the list. Non-ascii hostnames @@ -98,6 +85,10 @@ struct NET_EXPORT_PRIVATE DnsSystemSettings { // SOFTWARE\Policies\Microsoft\Windows NT\DNSClient\AppendToMultiLabelName RegDword append_to_multi_label_name; + + // True when the Name Resolution Policy Table (NRPT) has at least one rule: + // SOFTWARE\Policies\Microsoft\Windows NT\DNSClient\DnsPolicyConfig\Rule* + bool have_name_resolution_policy; }; enum ConfigParseWinResult { @@ -113,6 +104,7 @@ enum ConfigParseWinResult { CONFIG_PARSE_WIN_READ_PRIMARY_SUFFIX, CONFIG_PARSE_WIN_BAD_ADDRESS, CONFIG_PARSE_WIN_NO_NAMESERVERS, + CONFIG_PARSE_WIN_UNHANDLED_OPTIONS, CONFIG_PARSE_WIN_MAX // Bounding values for enumeration. }; diff --git a/chromium/net/dns/dns_config_service_win_unittest.cc b/chromium/net/dns/dns_config_service_win_unittest.cc index b28b8e9554a..3f3e4ed1e37 100644 --- a/chromium/net/dns/dns_config_service_win_unittest.cc +++ b/chromium/net/dns/dns_config_service_win_unittest.cc @@ -4,6 +4,7 @@ #include "net/dns/dns_config_service_win.h" +#include "base/basictypes.h" #include "base/logging.h" #include "base/win/windows_version.h" #include "net/dns/dns_protocol.h" @@ -420,10 +421,46 @@ TEST(DnsConfigServiceWinTest, AppendToMultiLabelName) { DnsConfig config; EXPECT_EQ(internal::CONFIG_PARSE_WIN_OK, internal::ConvertSettingsToDnsConfig(settings, &config)); - EXPECT_EQ(config.append_to_multi_label_name, t.expected_output); + EXPECT_EQ(t.expected_output, config.append_to_multi_label_name); } } +// Setting have_name_resolution_policy_table should set unhandled_options. +TEST(DnsConfigServiceWinTest, HaveNRPT) { + AdapterInfo infos[2] = { + { IF_TYPE_USB, IfOperStatusUp, L"connection.suffix", { "1.0.0.1" } }, + { 0 }, + }; + + const struct TestCase { + bool have_nrpt; + bool unhandled_options; + internal::ConfigParseWinResult result; + } cases[] = { + { false, false, internal::CONFIG_PARSE_WIN_OK }, + { true, true, internal::CONFIG_PARSE_WIN_UNHANDLED_OPTIONS }, + }; + + for (size_t i = 0; i < arraysize(cases); ++i) { + const TestCase& t = cases[i]; + internal::DnsSystemSettings settings = { + CreateAdapterAddresses(infos), + { false }, { false }, { false }, { false }, + { { false }, { false } }, + { { false }, { false } }, + { { false }, { false } }, + { false }, + t.have_nrpt, + }; + DnsConfig config; + EXPECT_EQ(t.result, + internal::ConvertSettingsToDnsConfig(settings, &config)); + EXPECT_EQ(t.unhandled_options, config.unhandled_options); + EXPECT_EQ(t.have_nrpt, config.use_local_ipv6); + } +} + + } // namespace } // namespace net diff --git a/chromium/net/dns/dns_hosts.cc b/chromium/net/dns/dns_hosts.cc index 852d35c8bb4..3edea2a7abc 100644 --- a/chromium/net/dns/dns_hosts.cc +++ b/chromium/net/dns/dns_hosts.cc @@ -158,7 +158,7 @@ bool ParseHostsFile(const base::FilePath& path, DnsHosts* dns_hosts) { return false; std::string contents; - if (!file_util::ReadFileToString(path, &contents)) + if (!base::ReadFileToString(path, &contents)) return false; ParseHosts(contents, dns_hosts); diff --git a/chromium/net/dns/dns_session_unittest.cc b/chromium/net/dns/dns_session_unittest.cc index 46627069f66..ed726f23234 100644 --- a/chromium/net/dns/dns_session_unittest.cc +++ b/chromium/net/dns/dns_session_unittest.cc @@ -14,6 +14,8 @@ #include "net/dns/dns_protocol.h" #include "net/dns/dns_socket_pool.h" #include "net/socket/socket_test_util.h" +#include "net/socket/ssl_client_socket.h" +#include "net/socket/stream_socket.h" #include "testing/gtest/include/gtest/gtest.h" namespace net { @@ -24,26 +26,26 @@ class TestClientSocketFactory : public ClientSocketFactory { public: virtual ~TestClientSocketFactory(); - virtual DatagramClientSocket* CreateDatagramClientSocket( + virtual scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket( DatagramSocket::BindType bind_type, const RandIntCallback& rand_int_cb, net::NetLog* net_log, const net::NetLog::Source& source) OVERRIDE; - virtual StreamSocket* CreateTransportClientSocket( + virtual scoped_ptr<StreamSocket> CreateTransportClientSocket( const AddressList& addresses, NetLog*, const NetLog::Source&) OVERRIDE { NOTIMPLEMENTED(); - return NULL; + return scoped_ptr<StreamSocket>(); } - virtual SSLClientSocket* CreateSSLClientSocket( - ClientSocketHandle* transport_socket, + virtual scoped_ptr<SSLClientSocket> CreateSSLClientSocket( + scoped_ptr<ClientSocketHandle> transport_socket, const HostPortPair& host_and_port, const SSLConfig& ssl_config, const SSLClientSocketContext& context) OVERRIDE { NOTIMPLEMENTED(); - return NULL; + return scoped_ptr<SSLClientSocket>(); } virtual void ClearSSLSessionCache() OVERRIDE { @@ -179,7 +181,8 @@ bool DnsSessionTest::ExpectEvent(const PoolEvent& expected) { return true; } -DatagramClientSocket* TestClientSocketFactory::CreateDatagramClientSocket( +scoped_ptr<DatagramClientSocket> +TestClientSocketFactory::CreateDatagramClientSocket( DatagramSocket::BindType bind_type, const RandIntCallback& rand_int_cb, net::NetLog* net_log, @@ -188,9 +191,10 @@ DatagramClientSocket* TestClientSocketFactory::CreateDatagramClientSocket( // simplest SocketDataProvider with no data supplied. SocketDataProvider* data_provider = new StaticSocketDataProvider(); data_providers_.push_back(data_provider); - MockUDPClientSocket* socket = new MockUDPClientSocket(data_provider, net_log); - data_provider->set_socket(socket); - return socket; + scoped_ptr<MockUDPClientSocket> socket( + new MockUDPClientSocket(data_provider, net_log)); + data_provider->set_socket(socket.get()); + return socket.PassAs<DatagramClientSocket>(); } TestClientSocketFactory::~TestClientSocketFactory() { diff --git a/chromium/net/dns/dns_socket_pool.cc b/chromium/net/dns/dns_socket_pool.cc index 64570fca8fc..7a7ecd6ee8f 100644 --- a/chromium/net/dns/dns_socket_pool.cc +++ b/chromium/net/dns/dns_socket_pool.cc @@ -76,8 +76,8 @@ scoped_ptr<DatagramClientSocket> DnsSocketPool::CreateConnectedSocket( scoped_ptr<DatagramClientSocket> socket; NetLog::Source no_source; - socket.reset(socket_factory_->CreateDatagramClientSocket( - kBindType, base::Bind(&base::RandInt), net_log_, no_source)); + socket = socket_factory_->CreateDatagramClientSocket( + kBindType, base::Bind(&base::RandInt), net_log_, no_source); if (socket.get()) { int rv = socket->Connect((*nameservers_)[server_index]); diff --git a/chromium/net/dns/dns_test_util.cc b/chromium/net/dns/dns_test_util.cc index 37bf855dd10..63014213faa 100644 --- a/chromium/net/dns/dns_test_util.cc +++ b/chromium/net/dns/dns_test_util.cc @@ -15,9 +15,6 @@ #include "net/base/io_buffer.h" #include "net/base/net_errors.h" #include "net/dns/address_sorter.h" -#include "net/dns/dns_client.h" -#include "net/dns/dns_config_service.h" -#include "net/dns/dns_protocol.h" #include "net/dns/dns_query.h" #include "net/dns/dns_response.h" #include "net/dns/dns_transaction.h" @@ -26,6 +23,16 @@ namespace net { namespace { +class MockAddressSorter : public AddressSorter { + public: + virtual ~MockAddressSorter() {} + virtual void Sort(const AddressList& list, + const CallbackType& callback) const OVERRIDE { + // Do nothing. + callback.Run(true, list); + } +}; + // A DnsTransaction which uses MockDnsClientRuleList to determine the response. class MockTransaction : public DnsTransaction, public base::SupportsWeakPtr<MockTransaction> { @@ -38,7 +45,8 @@ class MockTransaction : public DnsTransaction, hostname_(hostname), qtype_(qtype), callback_(callback), - started_(false) { + started_(false), + delayed_(false) { // Find the relevant rule which matches |qtype| and prefix of |hostname|. for (size_t i = 0; i < rules.size(); ++i) { const std::string& prefix = rules[i].prefix; @@ -46,6 +54,7 @@ class MockTransaction : public DnsTransaction, (hostname.size() >= prefix.size()) && (hostname.compare(0, prefix.size(), prefix) == 0)) { result_ = rules[i].result; + delayed_ = rules[i].delay; break; } } @@ -62,11 +71,21 @@ class MockTransaction : public DnsTransaction, virtual void Start() OVERRIDE { EXPECT_FALSE(started_); started_ = true; + if (delayed_) + return; // Using WeakPtr to cleanly cancel when transaction is destroyed. base::MessageLoop::current()->PostTask( FROM_HERE, base::Bind(&MockTransaction::Finish, AsWeakPtr())); } + void FinishDelayedTransaction() { + EXPECT_TRUE(delayed_); + delayed_ = false; + Finish(); + } + + bool delayed() const { return delayed_; } + private: void Finish() { switch (result_) { @@ -136,14 +155,17 @@ class MockTransaction : public DnsTransaction, const uint16 qtype_; DnsTransactionFactory::CallbackType callback_; bool started_; + bool delayed_; }; +} // namespace // A DnsTransactionFactory which creates MockTransaction. class MockTransactionFactory : public DnsTransactionFactory { public: explicit MockTransactionFactory(const MockDnsClientRuleList& rules) : rules_(rules) {} + virtual ~MockTransactionFactory() {} virtual scoped_ptr<DnsTransaction> CreateTransaction( @@ -151,60 +173,57 @@ class MockTransactionFactory : public DnsTransactionFactory { uint16 qtype, const DnsTransactionFactory::CallbackType& callback, const BoundNetLog&) OVERRIDE { - return scoped_ptr<DnsTransaction>( - new MockTransaction(rules_, hostname, qtype, callback)); + MockTransaction* transaction = + new MockTransaction(rules_, hostname, qtype, callback); + if (transaction->delayed()) + delayed_transactions_.push_back(transaction->AsWeakPtr()); + return scoped_ptr<DnsTransaction>(transaction); + } + + void CompleteDelayedTransactions() { + DelayedTransactionList old_delayed_transactions; + old_delayed_transactions.swap(delayed_transactions_); + for (DelayedTransactionList::iterator it = old_delayed_transactions.begin(); + it != old_delayed_transactions.end(); ++it) { + if (it->get()) + (*it)->FinishDelayedTransaction(); + } } private: - MockDnsClientRuleList rules_; -}; + typedef std::vector<base::WeakPtr<MockTransaction> > DelayedTransactionList; -class MockAddressSorter : public AddressSorter { - public: - virtual ~MockAddressSorter() {} - virtual void Sort(const AddressList& list, - const CallbackType& callback) const OVERRIDE { - // Do nothing. - callback.Run(true, list); - } + MockDnsClientRuleList rules_; + DelayedTransactionList delayed_transactions_; }; -// MockDnsClient provides MockTransactionFactory. -class MockDnsClient : public DnsClient { - public: - MockDnsClient(const DnsConfig& config, - const MockDnsClientRuleList& rules) - : config_(config), factory_(rules) {} - virtual ~MockDnsClient() {} +MockDnsClient::MockDnsClient(const DnsConfig& config, + const MockDnsClientRuleList& rules) + : config_(config), + factory_(new MockTransactionFactory(rules)), + address_sorter_(new MockAddressSorter()) { +} - virtual void SetConfig(const DnsConfig& config) OVERRIDE { - config_ = config; - } +MockDnsClient::~MockDnsClient() {} - virtual const DnsConfig* GetConfig() const OVERRIDE { - return config_.IsValid() ? &config_ : NULL; - } - - virtual DnsTransactionFactory* GetTransactionFactory() OVERRIDE { - return config_.IsValid() ? &factory_ : NULL; - } +void MockDnsClient::SetConfig(const DnsConfig& config) { + config_ = config; +} - virtual AddressSorter* GetAddressSorter() OVERRIDE { - return &address_sorter_; - } +const DnsConfig* MockDnsClient::GetConfig() const { + return config_.IsValid() ? &config_ : NULL; +} - private: - DnsConfig config_; - MockTransactionFactory factory_; - MockAddressSorter address_sorter_; -}; +DnsTransactionFactory* MockDnsClient::GetTransactionFactory() { + return config_.IsValid() ? factory_.get() : NULL; +} -} // namespace +AddressSorter* MockDnsClient::GetAddressSorter() { + return address_sorter_.get(); +} -// static -scoped_ptr<DnsClient> CreateMockDnsClient(const DnsConfig& config, - const MockDnsClientRuleList& rules) { - return scoped_ptr<DnsClient>(new MockDnsClient(config, rules)); +void MockDnsClient::CompleteDelayedTransactions() { + factory_->CompleteDelayedTransactions(); } } // namespace net diff --git a/chromium/net/dns/dns_test_util.h b/chromium/net/dns/dns_test_util.h index d447b299c86..d0b8e81b7ed 100644 --- a/chromium/net/dns/dns_test_util.h +++ b/chromium/net/dns/dns_test_util.h @@ -10,6 +10,7 @@ #include "base/basictypes.h" #include "base/memory/scoped_ptr.h" +#include "net/dns/dns_client.h" #include "net/dns/dns_config_service.h" #include "net/dns/dns_protocol.h" @@ -174,7 +175,9 @@ static const int kT3TTL = 0x00000015; // +2 for the CNAME records, +1 for TXT record. static const unsigned kT3RecordCount = arraysize(kT3IpAddresses) + 3; +class AddressSorter; class DnsClient; +class MockTransactionFactory; struct MockDnsClientRule { enum Result { @@ -184,21 +187,43 @@ struct MockDnsClientRule { OK, // Return a response with loopback address. }; + // If |delay| is true, matching transactions will be delayed until triggered + // by the consumer. MockDnsClientRule(const std::string& prefix_arg, uint16 qtype_arg, - Result result_arg) - : result(result_arg), prefix(prefix_arg), qtype(qtype_arg) { } + Result result_arg, + bool delay) + : result(result_arg), prefix(prefix_arg), qtype(qtype_arg), + delay(delay) {} Result result; std::string prefix; uint16 qtype; + bool delay; }; typedef std::vector<MockDnsClientRule> MockDnsClientRuleList; -// Creates mock DnsClient for testing HostResolverImpl. -scoped_ptr<DnsClient> CreateMockDnsClient(const DnsConfig& config, - const MockDnsClientRuleList& rules); +// MockDnsClient provides MockTransactionFactory. +class MockDnsClient : public DnsClient { + public: + MockDnsClient(const DnsConfig& config, const MockDnsClientRuleList& rules); + virtual ~MockDnsClient(); + + // DnsClient interface: + virtual void SetConfig(const DnsConfig& config) OVERRIDE; + virtual const DnsConfig* GetConfig() const OVERRIDE; + virtual DnsTransactionFactory* GetTransactionFactory() OVERRIDE; + virtual AddressSorter* GetAddressSorter() OVERRIDE; + + // Completes all DnsTransactions that were delayed by a rule. + void CompleteDelayedTransactions(); + + private: + DnsConfig config_; + scoped_ptr<MockTransactionFactory> factory_; + scoped_ptr<AddressSorter> address_sorter_; +}; } // namespace net diff --git a/chromium/net/dns/dns_transaction_unittest.cc b/chromium/net/dns/dns_transaction_unittest.cc index f9667eed5f4..7040e44be16 100644 --- a/chromium/net/dns/dns_transaction_unittest.cc +++ b/chromium/net/dns/dns_transaction_unittest.cc @@ -180,21 +180,21 @@ class TestSocketFactory : public MockClientSocketFactory { TestSocketFactory() : fail_next_socket_(false) {} virtual ~TestSocketFactory() {} - virtual DatagramClientSocket* CreateDatagramClientSocket( + virtual scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket( DatagramSocket::BindType bind_type, const RandIntCallback& rand_int_cb, net::NetLog* net_log, const net::NetLog::Source& source) OVERRIDE { if (fail_next_socket_) { fail_next_socket_ = false; - return new FailingUDPClientSocket(&empty_data_, net_log); + return scoped_ptr<DatagramClientSocket>( + new FailingUDPClientSocket(&empty_data_, net_log)); } SocketDataProvider* data_provider = mock_data().GetNext(); - TestUDPClientSocket* socket = new TestUDPClientSocket(this, - data_provider, - net_log); - data_provider->set_socket(socket); - return socket; + scoped_ptr<TestUDPClientSocket> socket( + new TestUDPClientSocket(this, data_provider, net_log)); + data_provider->set_socket(socket.get()); + return socket.PassAs<DatagramClientSocket>(); } void OnConnect(const IPEndPoint& endpoint) { diff --git a/chromium/net/dns/host_resolver.cc b/chromium/net/dns/host_resolver.cc index d74be91beb0..9b1e3314acd 100644 --- a/chromium/net/dns/host_resolver.cc +++ b/chromium/net/dns/host_resolver.cc @@ -98,9 +98,7 @@ HostResolver::RequestInfo::RequestInfo(const HostPortPair& host_port_pair) address_family_(ADDRESS_FAMILY_UNSPECIFIED), host_resolver_flags_(0), allow_cached_response_(true), - is_speculative_(false), - priority_(MEDIUM) { -} + is_speculative_(false) {} HostResolver::~HostResolver() { } diff --git a/chromium/net/dns/host_resolver.h b/chromium/net/dns/host_resolver.h index 558a1ddc4b2..2964fe600df 100644 --- a/chromium/net/dns/host_resolver.h +++ b/chromium/net/dns/host_resolver.h @@ -53,8 +53,8 @@ class NET_EXPORT HostResolver { bool enable_caching; }; - // The parameters for doing a Resolve(). A hostname and port are required, - // the rest are optional (and have reasonable defaults). + // The parameters for doing a Resolve(). A hostname and port are + // required; the rest are optional (and have reasonable defaults). class NET_EXPORT RequestInfo { public: explicit RequestInfo(const HostPortPair& host_port_pair); @@ -85,9 +85,6 @@ class NET_EXPORT HostResolver { bool is_speculative() const { return is_speculative_; } void set_is_speculative(bool b) { is_speculative_ = b; } - RequestPriority priority() const { return priority_; } - void set_priority(RequestPriority priority) { priority_ = priority; } - private: // The hostname to resolve, and the port to use in resulting sockaddrs. HostPortPair host_port_pair_; @@ -103,9 +100,6 @@ class NET_EXPORT HostResolver { // Whether this request was started by the DNS prefetcher. bool is_speculative_; - - // The priority for the request. - RequestPriority priority_; }; // Opaque type used to cancel a request. @@ -144,6 +138,7 @@ class NET_EXPORT HostResolver { // // Profiling information for the request is saved to |net_log| if non-NULL. virtual int Resolve(const RequestInfo& info, + RequestPriority priority, AddressList* addresses, const CompletionCallback& callback, RequestHandle* out_req, diff --git a/chromium/net/dns/host_resolver_impl.cc b/chromium/net/dns/host_resolver_impl.cc index e656f03955f..e4098d003be 100644 --- a/chromium/net/dns/host_resolver_impl.cc +++ b/chromium/net/dns/host_resolver_impl.cc @@ -69,10 +69,6 @@ const unsigned kNegativeCacheEntryTTLSeconds = 0; // Minimum TTL for successful resolutions with DnsTask. const unsigned kMinimumTTLSeconds = kCacheEntryTTLSeconds; -// Number of consecutive failures of DnsTask (with successful fallback) before -// the DnsClient is disabled until the next DNS change. -const unsigned kMaximumDnsFailures = 16; - // We use a separate histogram name for each platform to facilitate the // display of error codes by their symbolic name (since each platform has // different mappings). @@ -343,7 +339,6 @@ base::Value* NetLogRequestInfoCallback(const NetLog::Source& source, static_cast<int>(info->address_family())); dict->SetBoolean("allow_cached_response", info->allow_cached_response()); dict->SetBoolean("is_speculative", info->is_speculative()); - dict->SetInteger("priority", info->priority()); return dict; } @@ -457,6 +452,8 @@ class PriorityTracker { //----------------------------------------------------------------------------- +const unsigned HostResolverImpl::kMaximumDnsFailures = 16; + // Holds the data for a request that could not be completed synchronously. // It is owned by a Job. Canceled Requests are only marked as canceled rather // than removed from the Job's |requests_| list. @@ -465,16 +462,17 @@ class HostResolverImpl::Request { Request(const BoundNetLog& source_net_log, const BoundNetLog& request_net_log, const RequestInfo& info, + RequestPriority priority, const CompletionCallback& callback, AddressList* addresses) : source_net_log_(source_net_log), request_net_log_(request_net_log), info_(info), + priority_(priority), job_(NULL), callback_(callback), addresses_(addresses), - request_time_(base::TimeTicks::Now()) { - } + request_time_(base::TimeTicks::Now()) {} // Mark the request as canceled. void MarkAsCanceled() { @@ -521,16 +519,19 @@ class HostResolverImpl::Request { return info_; } - base::TimeTicks request_time() const { - return request_time_; - } + RequestPriority priority() const { return priority_; } + + base::TimeTicks request_time() const { return request_time_; } private: BoundNetLog source_net_log_; BoundNetLog request_net_log_; // The request info that started the request. - RequestInfo info_; + const RequestInfo info_; + + // TODO(akalin): Support reprioritization. + const RequestPriority priority_; // The resolve job that this request is dependent on. Job* job_; @@ -967,54 +968,98 @@ class HostResolverImpl::LoopbackProbeJob { // TODO(szym): This could be moved to separate source file as well. class HostResolverImpl::DnsTask : public base::SupportsWeakPtr<DnsTask> { public: - typedef base::Callback<void(int net_error, - const AddressList& addr_list, - base::TimeDelta ttl)> Callback; + class Delegate { + public: + virtual void OnDnsTaskComplete(base::TimeTicks start_time, + int net_error, + const AddressList& addr_list, + base::TimeDelta ttl) = 0; + + // Called when the first of two jobs succeeds. If the first completed + // transaction fails, this is not called. Also not called when the DnsTask + // only needs to run one transaction. + virtual void OnFirstDnsTransactionComplete() = 0; + + protected: + Delegate() {} + virtual ~Delegate() {} + }; DnsTask(DnsClient* client, const Key& key, - const Callback& callback, + Delegate* delegate, const BoundNetLog& job_net_log) : client_(client), - family_(key.address_family), - callback_(callback), - net_log_(job_net_log) { + key_(key), + delegate_(delegate), + net_log_(job_net_log), + num_completed_transactions_(0), + task_start_time_(base::TimeTicks::Now()) { DCHECK(client); - DCHECK(!callback.is_null()); - - // If unspecified, do IPv4 first, because suffix search will be faster. - uint16 qtype = (family_ == ADDRESS_FAMILY_IPV6) ? - dns_protocol::kTypeAAAA : - dns_protocol::kTypeA; - transaction_ = client_->GetTransactionFactory()->CreateTransaction( - key.hostname, - qtype, - base::Bind(&DnsTask::OnTransactionComplete, base::Unretained(this), - true /* first_query */, base::TimeTicks::Now()), - net_log_); + DCHECK(delegate_); } - void Start() { + bool needs_two_transactions() const { + return key_.address_family == ADDRESS_FAMILY_UNSPECIFIED; + } + + bool needs_another_transaction() const { + return needs_two_transactions() && !transaction_aaaa_; + } + + void StartFirstTransaction() { + DCHECK_EQ(0u, num_completed_transactions_); net_log_.BeginEvent(NetLog::TYPE_HOST_RESOLVER_IMPL_DNS_TASK); - transaction_->Start(); + if (key_.address_family == ADDRESS_FAMILY_IPV6) { + StartAAAA(); + } else { + StartA(); + } + } + + void StartSecondTransaction() { + DCHECK(needs_two_transactions()); + StartAAAA(); } private: - void OnTransactionComplete(bool first_query, - const base::TimeTicks& start_time, + void StartA() { + DCHECK(!transaction_a_); + DCHECK_NE(ADDRESS_FAMILY_IPV6, key_.address_family); + transaction_a_ = CreateTransaction(ADDRESS_FAMILY_IPV4); + transaction_a_->Start(); + } + + void StartAAAA() { + DCHECK(!transaction_aaaa_); + DCHECK_NE(ADDRESS_FAMILY_IPV4, key_.address_family); + transaction_aaaa_ = CreateTransaction(ADDRESS_FAMILY_IPV6); + transaction_aaaa_->Start(); + } + + scoped_ptr<DnsTransaction> CreateTransaction(AddressFamily family) { + DCHECK_NE(ADDRESS_FAMILY_UNSPECIFIED, family); + return client_->GetTransactionFactory()->CreateTransaction( + key_.hostname, + family == ADDRESS_FAMILY_IPV6 ? dns_protocol::kTypeAAAA : + dns_protocol::kTypeA, + base::Bind(&DnsTask::OnTransactionComplete, base::Unretained(this), + base::TimeTicks::Now()), + net_log_); + } + + void OnTransactionComplete(const base::TimeTicks& start_time, DnsTransaction* transaction, int net_error, const DnsResponse* response) { DCHECK(transaction); base::TimeDelta duration = base::TimeTicks::Now() - start_time; - // Run |callback_| last since the owning Job will then delete this DnsTask. if (net_error != OK) { DNS_HISTOGRAM("AsyncDNS.TransactionFailure", duration); OnFailure(net_error, DnsResponse::DNS_PARSE_OK); return; } - CHECK(response); DNS_HISTOGRAM("AsyncDNS.TransactionSuccess", duration); switch (transaction->GetType()) { case dns_protocol::kTypeA: @@ -1024,6 +1069,7 @@ class HostResolverImpl::DnsTask : public base::SupportsWeakPtr<DnsTask> { DNS_HISTOGRAM("AsyncDNS.TransactionSuccess_AAAA", duration); break; } + AddressList addr_list; base::TimeDelta ttl; DnsResponse::Result result = response->ParseToAddressList(&addr_list, &ttl); @@ -1036,58 +1082,53 @@ class HostResolverImpl::DnsTask : public base::SupportsWeakPtr<DnsTask> { return; } - bool needs_sort = false; - if (first_query) { - DCHECK(client_->GetConfig()) << - "Transaction should have been aborted when config changed!"; - if (family_ == ADDRESS_FAMILY_IPV6) { - needs_sort = (addr_list.size() > 1); - } else if (family_ == ADDRESS_FAMILY_UNSPECIFIED) { - first_addr_list_ = addr_list; - first_ttl_ = ttl; - // Use fully-qualified domain name to avoid search. - transaction_ = client_->GetTransactionFactory()->CreateTransaction( - response->GetDottedName() + ".", - dns_protocol::kTypeAAAA, - base::Bind(&DnsTask::OnTransactionComplete, base::Unretained(this), - false /* first_query */, base::TimeTicks::Now()), - net_log_); - transaction_->Start(); - return; - } + ++num_completed_transactions_; + if (num_completed_transactions_ == 1) { + ttl_ = ttl; } else { - DCHECK_EQ(ADDRESS_FAMILY_UNSPECIFIED, family_); - bool has_ipv6_addresses = !addr_list.empty(); - if (!first_addr_list_.empty()) { - ttl = std::min(ttl, first_ttl_); - // Place IPv4 addresses after IPv6. - addr_list.insert(addr_list.end(), first_addr_list_.begin(), - first_addr_list_.end()); - } - needs_sort = (has_ipv6_addresses && addr_list.size() > 1); + ttl_ = std::min(ttl_, ttl); } - if (addr_list.empty()) { + if (transaction->GetType() == dns_protocol::kTypeA) { + DCHECK_EQ(transaction_a_.get(), transaction); + // Place IPv4 addresses after IPv6. + addr_list_.insert(addr_list_.end(), addr_list.begin(), addr_list.end()); + } else { + DCHECK_EQ(transaction_aaaa_.get(), transaction); + // Place IPv6 addresses before IPv4. + addr_list_.insert(addr_list_.begin(), addr_list.begin(), addr_list.end()); + } + + if (needs_two_transactions() && num_completed_transactions_ == 1) { + // No need to repeat the suffix search. + key_.hostname = transaction->GetHostname(); + delegate_->OnFirstDnsTransactionComplete(); + return; + } + + if (addr_list_.empty()) { // TODO(szym): Don't fallback to ProcTask in this case. OnFailure(ERR_NAME_NOT_RESOLVED, DnsResponse::DNS_PARSE_OK); return; } - if (needs_sort) { - // Sort could complete synchronously. + // If there are multiple addresses, and at least one is IPv6, need to sort + // them. Note that IPv6 addresses are always put before IPv4 ones, so it's + // sufficient to just check the family of the first address. + if (addr_list_.size() > 1 && + addr_list_[0].GetFamily() == ADDRESS_FAMILY_IPV6) { + // Sort addresses if needed. Sort could complete synchronously. client_->GetAddressSorter()->Sort( - addr_list, + addr_list_, base::Bind(&DnsTask::OnSortComplete, AsWeakPtr(), - base::TimeTicks::Now(), - ttl)); + base::TimeTicks::Now())); } else { - OnSuccess(addr_list, ttl); + OnSuccess(addr_list_); } } void OnSortComplete(base::TimeTicks start_time, - base::TimeDelta ttl, bool success, const AddressList& addr_list) { if (!success) { @@ -1107,7 +1148,7 @@ class HostResolverImpl::DnsTask : public base::SupportsWeakPtr<DnsTask> { return; } - OnSuccess(addr_list, ttl); + OnSuccess(addr_list); } void OnFailure(int net_error, DnsResponse::Result result) { @@ -1115,26 +1156,34 @@ class HostResolverImpl::DnsTask : public base::SupportsWeakPtr<DnsTask> { net_log_.EndEvent( NetLog::TYPE_HOST_RESOLVER_IMPL_DNS_TASK, base::Bind(&NetLogDnsTaskFailedCallback, net_error, result)); - callback_.Run(net_error, AddressList(), base::TimeDelta()); + delegate_->OnDnsTaskComplete(task_start_time_, net_error, AddressList(), + base::TimeDelta()); } - void OnSuccess(const AddressList& addr_list, base::TimeDelta ttl) { + void OnSuccess(const AddressList& addr_list) { net_log_.EndEvent(NetLog::TYPE_HOST_RESOLVER_IMPL_DNS_TASK, addr_list.CreateNetLogCallback()); - callback_.Run(OK, addr_list, ttl); + delegate_->OnDnsTaskComplete(task_start_time_, OK, addr_list, ttl_); } DnsClient* client_; - AddressFamily family_; + Key key_; + // The listener to the results of this DnsTask. - Callback callback_; + Delegate* delegate_; const BoundNetLog net_log_; - scoped_ptr<DnsTransaction> transaction_; + scoped_ptr<DnsTransaction> transaction_a_; + scoped_ptr<DnsTransaction> transaction_aaaa_; + + unsigned num_completed_transactions_; - // Results from the first transaction. Used only if |family_| is unspecified. - AddressList first_addr_list_; - base::TimeDelta first_ttl_; + // These are updated as each transaction completes. + base::TimeDelta ttl_; + // IPv6 addresses must appear first in the list. + AddressList addr_list_; + + base::TimeTicks task_start_time_; DISALLOW_COPY_AND_ASSIGN(DnsTask); }; @@ -1142,7 +1191,8 @@ class HostResolverImpl::DnsTask : public base::SupportsWeakPtr<DnsTask> { //----------------------------------------------------------------------------- // Aggregates all Requests for the same Key. Dispatched via PriorityDispatch. -class HostResolverImpl::Job : public PrioritizedDispatcher::Job { +class HostResolverImpl::Job : public PrioritizedDispatcher::Job, + public HostResolverImpl::DnsTask::Delegate { public: // Creates new job for |key| where |request_net_log| is bound to the // request that spawned it. @@ -1155,6 +1205,7 @@ class HostResolverImpl::Job : public PrioritizedDispatcher::Job { priority_tracker_(priority), had_non_speculative_request_(false), had_dns_config_(false), + num_occupied_job_slots_(0), dns_task_error_(OK), creation_time_(base::TimeTicks::Now()), priority_change_time_(creation_time_), @@ -1178,7 +1229,7 @@ class HostResolverImpl::Job : public PrioritizedDispatcher::Job { proc_task_ = NULL; } // Clean up now for nice NetLog. - dns_task_.reset(NULL); + KillDnsTask(); net_log_.EndEventWithNetErrorCode(NetLog::TYPE_HOST_RESOLVER_IMPL_JOB, ERR_ABORTED); } else if (is_queued()) { @@ -1201,16 +1252,30 @@ class HostResolverImpl::Job : public PrioritizedDispatcher::Job { } } - // Add this job to the dispatcher. - void Schedule() { - handle_ = resolver_->dispatcher_.Add(this, priority()); + // Add this job to the dispatcher. If "at_head" is true, adds at the front + // of the queue. + void Schedule(bool at_head) { + DCHECK(!is_queued()); + PrioritizedDispatcher::Handle handle; + if (!at_head) { + handle = resolver_->dispatcher_.Add(this, priority()); + } else { + handle = resolver_->dispatcher_.AddAtHead(this, priority()); + } + // The dispatcher could have started |this| in the above call to Add, which + // could have called Schedule again. In that case |handle| will be null, + // but |handle_| may have been set by the other nested call to Schedule. + if (!handle.is_null()) { + DCHECK(handle_.is_null()); + handle_ = handle; + } } void AddRequest(scoped_ptr<Request> req) { DCHECK_EQ(key_.hostname, req->info().hostname()); req->set_job(this); - priority_tracker_.Add(req->info().priority()); + priority_tracker_.Add(req->priority()); req->request_net_log().AddEvent( NetLog::TYPE_HOST_RESOLVER_IMPL_JOB_ATTACH, @@ -1245,12 +1310,11 @@ class HostResolverImpl::Job : public PrioritizedDispatcher::Job { LogCancelRequest(req->source_net_log(), req->request_net_log(), req->info()); - priority_tracker_.Remove(req->info().priority()); - net_log_.AddEvent( - NetLog::TYPE_HOST_RESOLVER_IMPL_JOB_REQUEST_DETACH, - base::Bind(&NetLogJobAttachCallback, - req->request_net_log().source(), - priority())); + priority_tracker_.Remove(req->priority()); + net_log_.AddEvent(NetLog::TYPE_HOST_RESOLVER_IMPL_JOB_REQUEST_DETACH, + base::Bind(&NetLogJobAttachCallback, + req->request_net_log().source(), + priority())); if (num_active_requests() > 0) { UpdatePriority(); @@ -1272,7 +1336,7 @@ class HostResolverImpl::Job : public PrioritizedDispatcher::Job { // If DnsTask present, abort it and fall back to ProcTask. void AbortDnsTask() { if (dns_task_) { - dns_task_.reset(); + KillDnsTask(); dns_task_error_ = OK; StartProcTask(); } @@ -1321,6 +1385,29 @@ class HostResolverImpl::Job : public PrioritizedDispatcher::Job { } private: + void KillDnsTask() { + if (dns_task_) { + ReduceToOneJobSlot(); + dns_task_.reset(); + } + } + + // Reduce the number of job slots occupied and queued in the dispatcher + // to one. If the second Job slot is queued in the dispatcher, cancels the + // queued job. Otherwise, the second Job has been started by the + // PrioritizedDispatcher, so signals it is complete. + void ReduceToOneJobSlot() { + DCHECK_GE(num_occupied_job_slots_, 1u); + if (is_queued()) { + resolver_->dispatcher_.Cancel(handle_); + handle_.Reset(); + } else if (num_occupied_job_slots_ > 1) { + resolver_->dispatcher_.OnJobFinished(); + --num_occupied_job_slots_; + } + DCHECK_EQ(1u, num_occupied_job_slots_); + } + void UpdatePriority() { if (is_queued()) { if (priority() != static_cast<RequestPriority>(handle_.priority())) @@ -1337,8 +1424,17 @@ class HostResolverImpl::Job : public PrioritizedDispatcher::Job { // PriorityDispatch::Job: virtual void Start() OVERRIDE { - DCHECK(!is_running()); + DCHECK_LE(num_occupied_job_slots_, 1u); + handle_.Reset(); + ++num_occupied_job_slots_; + + if (num_occupied_job_slots_ == 2) { + StartSecondDnsTransaction(); + return; + } + + DCHECK(!is_running()); net_log_.AddEvent(NetLog::TYPE_HOST_RESOLVER_IMPL_JOB_STARTED); @@ -1359,8 +1455,12 @@ class HostResolverImpl::Job : public PrioritizedDispatcher::Job { queue_time_after_change); } + bool system_only = + (key_.host_resolver_flags & HOST_RESOLVER_SYSTEM_ONLY) != 0; + // Caution: Job::Start must not complete synchronously. - if (had_dns_config_ && !ResemblesMulticastDNSName(key_.hostname)) { + if (!system_only && had_dns_config_ && + !ResemblesMulticastDNSName(key_.hostname)) { StartDnsTask(); } else { StartProcTask(); @@ -1442,14 +1542,18 @@ class HostResolverImpl::Job : public PrioritizedDispatcher::Job { void StartDnsTask() { DCHECK(resolver_->HaveDnsConfig()); - base::TimeTicks start_time = base::TimeTicks::Now(); - dns_task_.reset(new DnsTask( - resolver_->dns_client_.get(), - key_, - base::Bind(&Job::OnDnsTaskComplete, base::Unretained(this), start_time), - net_log_)); + dns_task_.reset(new DnsTask(resolver_->dns_client_.get(), key_, this, + net_log_)); - dns_task_->Start(); + dns_task_->StartFirstTransaction(); + // Schedule a second transaction, if needed. + if (dns_task_->needs_two_transactions()) + Schedule(true); + } + + void StartSecondDnsTransaction() { + DCHECK(dns_task_->needs_two_transactions()); + dns_task_->StartSecondTransaction(); } // Called if DnsTask fails. It is posted from StartDnsTask, so Job may be @@ -1471,7 +1575,7 @@ class HostResolverImpl::Job : public PrioritizedDispatcher::Job { // TODO(szym): Some net errors indicate lack of connectivity. Starting // ProcTask in that case is a waste of time. if (resolver_->fallback_to_proctask_) { - dns_task_.reset(); + KillDnsTask(); StartProcTask(); } else { UmaAsyncDnsResolveStatus(RESOLVE_STATUS_FAIL); @@ -1479,11 +1583,13 @@ class HostResolverImpl::Job : public PrioritizedDispatcher::Job { } } - // Called by DnsTask when it completes. - void OnDnsTaskComplete(base::TimeTicks start_time, - int net_error, - const AddressList& addr_list, - base::TimeDelta ttl) { + + // HostResolverImpl::DnsTask::Delegate implementation: + + virtual void OnDnsTaskComplete(base::TimeTicks start_time, + int net_error, + const AddressList& addr_list, + base::TimeDelta ttl) OVERRIDE { DCHECK(is_dns_running()); base::TimeDelta duration = base::TimeTicks::Now() - start_time; @@ -1518,6 +1624,19 @@ class HostResolverImpl::Job : public PrioritizedDispatcher::Job { bounded_ttl); } + virtual void OnFirstDnsTransactionComplete() OVERRIDE { + DCHECK(dns_task_->needs_two_transactions()); + DCHECK_EQ(dns_task_->needs_another_transaction(), is_queued()); + // No longer need to occupy two dispatcher slots. + ReduceToOneJobSlot(); + + // We already have a job slot at the dispatcher, so if the second + // transaction hasn't started, reuse it now instead of waiting in the queue + // for the second slot. + if (dns_task_->needs_another_transaction()) + dns_task_->StartSecondTransaction(); + } + // Performs Job's last rites. Completes all Requests. Deletes this. void CompleteRequests(const HostCache::Entry& entry, base::TimeDelta ttl) { @@ -1532,12 +1651,12 @@ class HostResolverImpl::Job : public PrioritizedDispatcher::Job { resolver_->RemoveJob(this); if (is_running()) { - DCHECK(!is_queued()); if (is_proc_running()) { + DCHECK(!is_queued()); proc_task_->Cancel(); proc_task_ = NULL; } - dns_task_.reset(); + KillDnsTask(); // Signal dispatcher that a slot has opened. resolver_->dispatcher_.OnJobFinished(); @@ -1631,6 +1750,9 @@ class HostResolverImpl::Job : public PrioritizedDispatcher::Job { // Distinguishes measurements taken while DnsClient was fully configured. bool had_dns_config_; + // Number of slots occupied by this Job in resolver's PrioritizedDispatcher. + unsigned num_occupied_job_slots_; + // Result of DnsTask. int dns_task_error_; @@ -1681,6 +1803,7 @@ HostResolverImpl::HostResolverImpl( received_dns_config_(false), num_dns_failures_(0), probe_ipv6_support_(true), + use_local_ipv6_(false), resolved_known_ipv6_hostname_(false), additional_resolver_flags_(0), fallback_to_proctask_(true) { @@ -1706,19 +1829,22 @@ HostResolverImpl::HostResolverImpl( EnsureDnsReloaderInit(); #endif - // TODO(szym): Remove when received_dns_config_ is removed, once - // http://crbug.com/137914 is resolved. { DnsConfig dns_config; NetworkChangeNotifier::GetDnsConfig(&dns_config); received_dns_config_ = dns_config.IsValid(); + // Conservatively assume local IPv6 is needed when DnsConfig is not valid. + use_local_ipv6_ = !dns_config.IsValid() || dns_config.use_local_ipv6; } fallback_to_proctask_ = !ConfigureAsyncDnsNoFallbackFieldTrial(); } HostResolverImpl::~HostResolverImpl() { - // This will also cancel all outstanding requests. + // Prevent the dispatcher from starting new jobs. + dispatcher_.SetLimitsToZero(); + // It's now safe for Jobs to call KillDsnTask on destruction, because + // OnJobComplete will not start any new jobs. STLDeleteValues(&jobs_); NetworkChangeNotifier::RemoveIPAddressObserver(this); @@ -1732,6 +1858,7 @@ void HostResolverImpl::SetMaxQueuedJobs(size_t value) { } int HostResolverImpl::Resolve(const RequestInfo& info, + RequestPriority priority, AddressList* addresses, const CompletionCallback& callback, RequestHandle* out_req, @@ -1768,9 +1895,9 @@ int HostResolverImpl::Resolve(const RequestInfo& info, JobMap::iterator jobit = jobs_.find(key); Job* job; if (jobit == jobs_.end()) { - job = new Job(weak_ptr_factory_.GetWeakPtr(), key, info.priority(), - request_net_log); - job->Schedule(); + job = + new Job(weak_ptr_factory_.GetWeakPtr(), key, priority, request_net_log); + job->Schedule(false); // Check for queue overflow. if (dispatcher_.num_queued_jobs() > max_queued_jobs_) { @@ -1789,11 +1916,8 @@ int HostResolverImpl::Resolve(const RequestInfo& info, } // Can't complete synchronously. Create and attach request. - scoped_ptr<Request> req(new Request(source_net_log, - request_net_log, - info, - callback, - addresses)); + scoped_ptr<Request> req(new Request( + source_net_log, request_net_log, info, priority, callback, addresses)); if (out_req) *out_req = reinterpret_cast<RequestHandle>(req.get()); @@ -2023,35 +2147,35 @@ HostResolverImpl::Key HostResolverImpl::GetEffectiveKeyForRequest( AddressFamily effective_address_family = info.address_family(); if (info.address_family() == ADDRESS_FAMILY_UNSPECIFIED) { - base::TimeTicks start_time = base::TimeTicks::Now(); - // Google DNS address. - const uint8 kIPv6Address[] = - { 0x20, 0x01, 0x48, 0x60, 0x48, 0x60, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x88, 0x88 }; - IPAddressNumber address(kIPv6Address, - kIPv6Address + arraysize(kIPv6Address)); - bool rv6 = IsGloballyReachable(address, net_log); - if (rv6) - net_log.AddEvent(NetLog::TYPE_HOST_RESOLVER_IMPL_IPV6_SUPPORTED); - - UMA_HISTOGRAM_TIMES("Net.IPv6ConnectDuration", - base::TimeTicks::Now() - start_time); - if (rv6) { - UMA_HISTOGRAM_BOOLEAN("Net.IPv6ConnectSuccessMatch", - default_address_family_ == ADDRESS_FAMILY_UNSPECIFIED); + if (probe_ipv6_support_ && !use_local_ipv6_) { + base::TimeTicks start_time = base::TimeTicks::Now(); + // Google DNS address. + const uint8 kIPv6Address[] = + { 0x20, 0x01, 0x48, 0x60, 0x48, 0x60, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x88, 0x88 }; + IPAddressNumber address(kIPv6Address, + kIPv6Address + arraysize(kIPv6Address)); + bool rv6 = IsGloballyReachable(address, net_log); + if (rv6) + net_log.AddEvent(NetLog::TYPE_HOST_RESOLVER_IMPL_IPV6_SUPPORTED); + + UMA_HISTOGRAM_TIMES("Net.IPv6ConnectDuration", + base::TimeTicks::Now() - start_time); + if (rv6) { + UMA_HISTOGRAM_BOOLEAN("Net.IPv6ConnectSuccessMatch", + default_address_family_ == ADDRESS_FAMILY_UNSPECIFIED); + } else { + UMA_HISTOGRAM_BOOLEAN("Net.IPv6ConnectFailureMatch", + default_address_family_ != ADDRESS_FAMILY_UNSPECIFIED); + + effective_address_family = ADDRESS_FAMILY_IPV4; + effective_flags |= HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6; + } } else { - UMA_HISTOGRAM_BOOLEAN("Net.IPv6ConnectFailureMatch", - default_address_family_ != ADDRESS_FAMILY_UNSPECIFIED); + effective_address_family = default_address_family_; } } - if (effective_address_family == ADDRESS_FAMILY_UNSPECIFIED && - default_address_family_ != ADDRESS_FAMILY_UNSPECIFIED) { - effective_address_family = default_address_family_; - if (probe_ipv6_support_) - effective_flags |= HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6; - } - return Key(info.hostname(), effective_address_family, effective_flags); } @@ -2070,8 +2194,13 @@ void HostResolverImpl::AbortAllInProgressJobs() { } } - // Check if no dispatcher slots leaked out. - DCHECK_EQ(dispatcher_.num_running_jobs(), jobs_to_abort.size()); + // Pause the dispatcher so it won't start any new dispatcher jobs while + // aborting the old ones. This is needed so that it won't start the second + // DnsTransaction for a job in |jobs_to_abort| if the DnsConfig just became + // invalid. + PrioritizedDispatcher::Limits limits = dispatcher_.GetLimits(); + dispatcher_.SetLimits( + PrioritizedDispatcher::Limits(limits.reserved_slots.size(), 0)); // Life check to bail once |this| is deleted. base::WeakPtr<HostResolverImpl> self = weak_ptr_factory_.GetWeakPtr(); @@ -2081,6 +2210,22 @@ void HostResolverImpl::AbortAllInProgressJobs() { jobs_to_abort[i]->Abort(); jobs_to_abort[i] = NULL; } + + if (self) + dispatcher_.SetLimits(limits); +} + +void HostResolverImpl::AbortDnsTasks() { + // Pause the dispatcher so it won't start any new dispatcher jobs while + // aborting the old ones. This is needed so that it won't start the second + // DnsTransaction for a job if the DnsConfig just changed. + PrioritizedDispatcher::Limits limits = dispatcher_.GetLimits(); + dispatcher_.SetLimits( + PrioritizedDispatcher::Limits(limits.reserved_slots.size(), 0)); + + for (JobMap::iterator it = jobs_.begin(); it != jobs_.end(); ++it) + it->second->AbortDnsTask(); + dispatcher_.SetLimits(limits); } void HostResolverImpl::TryServingAllJobsFromHosts() { @@ -2126,6 +2271,8 @@ void HostResolverImpl::OnDNSChanged() { // TODO(szym): Remove once http://crbug.com/137914 is resolved. received_dns_config_ = dns_config.IsValid(); + // Conservatively assume local IPv6 is needed when DnsConfig is not valid. + use_local_ipv6_ = !dns_config.IsValid() || dns_config.use_local_ipv6; num_dns_failures_ = 0; @@ -2133,7 +2280,7 @@ void HostResolverImpl::OnDNSChanged() { // the newly started jobs use the new config. if (dns_client_.get()) { dns_client_->SetConfig(dns_config); - if (dns_config.IsValid()) + if (dns_client_->GetConfig()) UMA_HISTOGRAM_BOOLEAN("AsyncDNS.DnsClientEnabled", true); } @@ -2175,10 +2322,14 @@ void HostResolverImpl::OnDnsTaskResolve(int net_error) { ++num_dns_failures_; if (num_dns_failures_ < kMaximumDnsFailures) return; - // Disable DnsClient until the next DNS change. - for (JobMap::iterator it = jobs_.begin(); it != jobs_.end(); ++it) - it->second->AbortDnsTask(); + + // Disable DnsClient until the next DNS change. Must be done before aborting + // DnsTasks, since doing so may start new jobs. dns_client_->SetConfig(DnsConfig()); + + // Switch jobs with active DnsTasks over to using ProcTasks. + AbortDnsTasks(); + UMA_HISTOGRAM_BOOLEAN("AsyncDNS.DnsClientEnabled", false); UMA_HISTOGRAM_CUSTOM_ENUMERATION("AsyncDNS.DnsClientDisabledReason", std::abs(net_error), @@ -2186,21 +2337,20 @@ void HostResolverImpl::OnDnsTaskResolve(int net_error) { } void HostResolverImpl::SetDnsClient(scoped_ptr<DnsClient> dns_client) { - if (HaveDnsConfig()) { - for (JobMap::iterator it = jobs_.begin(); it != jobs_.end(); ++it) - it->second->AbortDnsTask(); - } + // DnsClient and config must be updated before aborting DnsTasks, since doing + // so may start new jobs. dns_client_ = dns_client.Pass(); - if (!dns_client_ || dns_client_->GetConfig() || - num_dns_failures_ >= kMaximumDnsFailures) { - return; + if (dns_client_ && !dns_client_->GetConfig() && + num_dns_failures_ < kMaximumDnsFailures) { + DnsConfig dns_config; + NetworkChangeNotifier::GetDnsConfig(&dns_config); + dns_client_->SetConfig(dns_config); + num_dns_failures_ = 0; + if (dns_client_->GetConfig()) + UMA_HISTOGRAM_BOOLEAN("AsyncDNS.DnsClientEnabled", true); } - DnsConfig dns_config; - NetworkChangeNotifier::GetDnsConfig(&dns_config); - dns_client_->SetConfig(dns_config); - num_dns_failures_ = 0; - if (dns_config.IsValid()) - UMA_HISTOGRAM_BOOLEAN("AsyncDNS.DnsClientEnabled", true); + + AbortDnsTasks(); } } // namespace net diff --git a/chromium/net/dns/host_resolver_impl.h b/chromium/net/dns/host_resolver_impl.h index 928d07af8b8..e7f5f6541ca 100644 --- a/chromium/net/dns/host_resolver_impl.h +++ b/chromium/net/dns/host_resolver_impl.h @@ -131,6 +131,7 @@ class NET_EXPORT HostResolverImpl // HostResolver methods: virtual int Resolve(const RequestInfo& info, + RequestPriority priority, AddressList* addresses, const CompletionCallback& callback, RequestHandle* out_req, @@ -156,6 +157,10 @@ class NET_EXPORT HostResolverImpl typedef std::map<Key, Job*> JobMap; typedef ScopedVector<Request> RequestsList; + // Number of consecutive failures of DnsTask (with successful fallback to + // ProcTask) before the DnsClient is disabled until the next DNS change. + static const unsigned kMaximumDnsFailures; + // Helper used by |Resolve()| and |ResolveFromCache()|. Performs IP // literal, cache and HOSTS lookup (if enabled), returns OK if successful, // ERR_NAME_NOT_RESOLVED if either hostname is invalid or IP literal is @@ -207,6 +212,11 @@ class NET_EXPORT HostResolverImpl // requests. Might start new jobs. void AbortAllInProgressJobs(); + // Aborts all in progress DnsTasks. In-progress jobs will fall back to + // ProcTasks. Might start new jobs, if any jobs were taking up two dispatcher + // slots. + void AbortDnsTasks(); + // Attempts to serve each Job in |jobs_| from the HOSTS file if we have // a DnsClient with a valid DnsConfig. void TryServingAllJobsFromHosts(); @@ -224,8 +234,10 @@ class NET_EXPORT HostResolverImpl // and resulted in |net_error|. void OnDnsTaskResolve(int net_error); - // Allows the tests to catch slots leaking out of the dispatcher. - size_t num_running_jobs_for_tests() const { + // Allows the tests to catch slots leaking out of the dispatcher. One + // HostResolverImpl::Job could occupy multiple PrioritizedDispatcher job + // slots. + size_t num_running_dispatcher_jobs_for_tests() const { return dispatcher_.num_running_jobs(); } @@ -267,6 +279,10 @@ class NET_EXPORT HostResolverImpl // explicit setting in |default_address_family_| is used. bool probe_ipv6_support_; + // True if DnsConfigService detected that system configuration depends on + // local IPv6 connectivity. Disables probing. + bool use_local_ipv6_; + // True iff ProcTask has successfully resolved a hostname known to have IPv6 // addresses using ADDRESS_FAMILY_UNSPECIFIED. Reset on IP address change. bool resolved_known_ipv6_hostname_; diff --git a/chromium/net/dns/host_resolver_impl_unittest.cc b/chromium/net/dns/host_resolver_impl_unittest.cc index f6b7f690675..c314f80c24d 100644 --- a/chromium/net/dns/host_resolver_impl_unittest.cc +++ b/chromium/net/dns/host_resolver_impl_unittest.cc @@ -12,6 +12,7 @@ #include "base/memory/ref_counted.h" #include "base/memory/scoped_vector.h" #include "base/message_loop/message_loop.h" +#include "base/run_loop.h" #include "base/strings/string_util.h" #include "base/strings/stringprintf.h" #include "base/synchronization/condition_variable.h" @@ -197,10 +198,12 @@ class Request { }; Request(const HostResolver::RequestInfo& info, + RequestPriority priority, size_t index, HostResolver* resolver, Handler* handler) : info_(info), + priority_(priority), index_(index), resolver_(resolver), handler_(handler), @@ -213,8 +216,12 @@ class Request { DCHECK(!handle_); list_ = AddressList(); result_ = resolver_->Resolve( - info_, &list_, base::Bind(&Request::OnComplete, base::Unretained(this)), - &handle_, BoundNetLog()); + info_, + priority_, + &list_, + base::Bind(&Request::OnComplete, base::Unretained(this)), + &handle_, + BoundNetLog()); if (!list_.empty()) EXPECT_EQ(OK, result_); return result_; @@ -291,6 +298,7 @@ class Request { } HostResolver::RequestInfo info_; + RequestPriority priority_; size_t index_; HostResolver* resolver_; Handler* handler_; @@ -413,14 +421,29 @@ class HostResolverImplTest : public testing::Test { HostResolverImplTest() : proc_(new MockHostResolverProc()) {} + void CreateResolver() { + CreateResolverWithLimitsAndParams(DefaultLimits(), + DefaultParams(proc_.get())); + } + + // This HostResolverImpl will only allow 1 outstanding resolve at a time and + // perform no retries. + void CreateSerialResolver() { + HostResolverImpl::ProcTaskParams params = DefaultParams(proc_.get()); + params.max_retry_attempts = 0u; + PrioritizedDispatcher::Limits limits(NUM_PRIORITIES, 1); + CreateResolverWithLimitsAndParams(limits, params); + } + protected: // A Request::Handler which is a proxy to the HostResolverImplTest fixture. struct Handler : public Request::Handler { virtual ~Handler() {} // Proxy functions so that classes derived from Handler can access them. - Request* CreateRequest(const HostResolver::RequestInfo& info) { - return test->CreateRequest(info); + Request* CreateRequest(const HostResolver::RequestInfo& info, + RequestPriority priority) { + return test->CreateRequest(info, priority); } Request* CreateRequest(const std::string& hostname, int port) { return test->CreateRequest(hostname, port); @@ -435,31 +458,30 @@ class HostResolverImplTest : public testing::Test { HostResolverImplTest* test; }; - void CreateResolver() { - resolver_.reset(new HostResolverImpl(HostCache::CreateDefaultCache(), - DefaultLimits(), - DefaultParams(proc_.get()), - NULL)); + // testing::Test implementation: + virtual void SetUp() OVERRIDE { + CreateResolver(); } - // This HostResolverImpl will only allow 1 outstanding resolve at a time and - // perform no retries. - void CreateSerialResolver() { - HostResolverImpl::ProcTaskParams params = DefaultParams(proc_.get()); - params.max_retry_attempts = 0u; - PrioritizedDispatcher::Limits limits(NUM_PRIORITIES, 1); - resolver_.reset(new HostResolverImpl( - HostCache::CreateDefaultCache(), - limits, - params, - NULL)); + virtual void TearDown() OVERRIDE { + if (resolver_.get()) + EXPECT_EQ(0u, resolver_->num_running_dispatcher_jobs_for_tests()); + EXPECT_FALSE(proc_->HasBlockedRequests()); + } + + virtual void CreateResolverWithLimitsAndParams( + const PrioritizedDispatcher::Limits& limits, + const HostResolverImpl::ProcTaskParams& params) { + resolver_.reset(new HostResolverImpl(HostCache::CreateDefaultCache(), + limits, params, NULL)); } // The Request will not be made until a call to |Resolve()|, and the Job will // not start until released by |proc_->SignalXXX|. - Request* CreateRequest(const HostResolver::RequestInfo& info) { - Request* req = new Request(info, requests_.size(), resolver_.get(), - handler_.get()); + Request* CreateRequest(const HostResolver::RequestInfo& info, + RequestPriority priority) { + Request* req = new Request( + info, priority, requests_.size(), resolver_.get(), handler_.get()); requests_.push_back(req); return req; } @@ -469,9 +491,8 @@ class HostResolverImplTest : public testing::Test { RequestPriority priority, AddressFamily family) { HostResolver::RequestInfo info(HostPortPair(hostname, port)); - info.set_priority(priority); info.set_address_family(family); - return CreateRequest(info); + return CreateRequest(info, priority); } Request* CreateRequest(const std::string& hostname, @@ -488,25 +509,15 @@ class HostResolverImplTest : public testing::Test { return CreateRequest(hostname, kDefaultPort); } - virtual void SetUp() OVERRIDE { - CreateResolver(); - } - - virtual void TearDown() OVERRIDE { - if (resolver_.get()) - EXPECT_EQ(0u, resolver_->num_running_jobs_for_tests()); - EXPECT_FALSE(proc_->HasBlockedRequests()); - } - void set_handler(Handler* handler) { handler_.reset(handler); handler_->test = this; } // Friendship is not inherited, so use proxies to access those. - size_t num_running_jobs() const { + size_t num_running_dispatcher_jobs() const { DCHECK(resolver_.get()); - return resolver_->num_running_jobs_for_tests(); + return resolver_->num_running_dispatcher_jobs_for_tests(); } void set_fallback_to_proctask(bool fallback_to_proctask) { @@ -514,6 +525,10 @@ class HostResolverImplTest : public testing::Test { resolver_->fallback_to_proctask_ = fallback_to_proctask; } + static unsigned maximum_dns_failures() { + return HostResolverImpl::kMaximumDnsFailures; + } + scoped_refptr<MockHostResolverProc> proc_; scoped_ptr<HostResolverImpl> resolver_; ScopedVector<Request> requests_; @@ -817,7 +832,8 @@ TEST_F(HostResolverImplTest, BypassCache) { // longer service the request synchronously. HostResolver::RequestInfo info(HostPortPair(hostname, 71)); info.set_allow_cached_response(false); - EXPECT_EQ(ERR_IO_PENDING, CreateRequest(info)->Resolve()); + EXPECT_EQ(ERR_IO_PENDING, + CreateRequest(info, DEFAULT_PRIORITY)->Resolve()); } else if (71 == req->info().port()) { // Test is done. base::MessageLoop::current()->Quit(); @@ -889,7 +905,7 @@ TEST_F(HostResolverImplTest, ObeyPoolConstraintsAfterIPAddressChange) { EXPECT_EQ(ERR_NETWORK_CHANGED, requests_[0]->WaitForResult()); - EXPECT_EQ(1u, num_running_jobs()); + EXPECT_EQ(1u, num_running_dispatcher_jobs()); EXPECT_FALSE(requests_[1]->completed()); EXPECT_FALSE(requests_[2]->completed()); @@ -1189,14 +1205,15 @@ TEST_F(HostResolverImplTest, ResolveFromCache) { HostResolver::RequestInfo info(HostPortPair("just.testing", 80)); // First hit will miss the cache. - EXPECT_EQ(ERR_DNS_CACHE_MISS, CreateRequest(info)->ResolveFromCache()); + EXPECT_EQ(ERR_DNS_CACHE_MISS, + CreateRequest(info, DEFAULT_PRIORITY)->ResolveFromCache()); // This time, we fetch normally. - EXPECT_EQ(ERR_IO_PENDING, CreateRequest(info)->Resolve()); + EXPECT_EQ(ERR_IO_PENDING, CreateRequest(info, DEFAULT_PRIORITY)->Resolve()); EXPECT_EQ(OK, requests_[1]->WaitForResult()); // Now we should be able to fetch from the cache. - EXPECT_EQ(OK, CreateRequest(info)->ResolveFromCache()); + EXPECT_EQ(OK, CreateRequest(info, DEFAULT_PRIORITY)->ResolveFromCache()); EXPECT_TRUE(requests_[2]->HasOneAddress("192.168.1.42", 80)); } @@ -1228,7 +1245,7 @@ TEST_F(HostResolverImplTest, MultipleAttempts) { // Resolve "host1". HostResolver::RequestInfo info(HostPortPair("host1", 70)); - Request* req = CreateRequest(info); + Request* req = CreateRequest(info, DEFAULT_PRIORITY); EXPECT_EQ(ERR_IO_PENDING, req->Resolve()); // Resolve returns -4 to indicate that 3rd attempt has resolved the host. @@ -1255,36 +1272,70 @@ DnsConfig CreateValidDnsConfig() { // Specialized fixture for tests of DnsTask. class HostResolverImplDnsTest : public HostResolverImplTest { + public: + HostResolverImplDnsTest() : dns_client_(NULL) {} + protected: + // testing::Test implementation: virtual void SetUp() OVERRIDE { - AddDnsRule("nx", dns_protocol::kTypeA, MockDnsClientRule::FAIL); - AddDnsRule("nx", dns_protocol::kTypeAAAA, MockDnsClientRule::FAIL); - AddDnsRule("ok", dns_protocol::kTypeA, MockDnsClientRule::OK); - AddDnsRule("ok", dns_protocol::kTypeAAAA, MockDnsClientRule::OK); - AddDnsRule("4ok", dns_protocol::kTypeA, MockDnsClientRule::OK); - AddDnsRule("4ok", dns_protocol::kTypeAAAA, MockDnsClientRule::EMPTY); - AddDnsRule("6ok", dns_protocol::kTypeA, MockDnsClientRule::EMPTY); - AddDnsRule("6ok", dns_protocol::kTypeAAAA, MockDnsClientRule::OK); - AddDnsRule("4nx", dns_protocol::kTypeA, MockDnsClientRule::OK); - AddDnsRule("4nx", dns_protocol::kTypeAAAA, MockDnsClientRule::FAIL); + AddDnsRule("nx", dns_protocol::kTypeA, MockDnsClientRule::FAIL, false); + AddDnsRule("nx", dns_protocol::kTypeAAAA, MockDnsClientRule::FAIL, false); + AddDnsRule("ok", dns_protocol::kTypeA, MockDnsClientRule::OK, false); + AddDnsRule("ok", dns_protocol::kTypeAAAA, MockDnsClientRule::OK, false); + AddDnsRule("4ok", dns_protocol::kTypeA, MockDnsClientRule::OK, false); + AddDnsRule("4ok", dns_protocol::kTypeAAAA, MockDnsClientRule::EMPTY, false); + AddDnsRule("6ok", dns_protocol::kTypeA, MockDnsClientRule::EMPTY, false); + AddDnsRule("6ok", dns_protocol::kTypeAAAA, MockDnsClientRule::OK, false); + AddDnsRule("4nx", dns_protocol::kTypeA, MockDnsClientRule::OK, false); + AddDnsRule("4nx", dns_protocol::kTypeAAAA, MockDnsClientRule::FAIL, false); + AddDnsRule("empty", dns_protocol::kTypeA, MockDnsClientRule::EMPTY, false); + AddDnsRule("empty", dns_protocol::kTypeAAAA, MockDnsClientRule::EMPTY, + false); + + AddDnsRule("slow_nx", dns_protocol::kTypeA, MockDnsClientRule::FAIL, true); + AddDnsRule("slow_nx", dns_protocol::kTypeAAAA, MockDnsClientRule::FAIL, + true); + + AddDnsRule("4slow_ok", dns_protocol::kTypeA, MockDnsClientRule::OK, true); + AddDnsRule("4slow_ok", dns_protocol::kTypeAAAA, MockDnsClientRule::OK, + false); + AddDnsRule("6slow_ok", dns_protocol::kTypeA, MockDnsClientRule::OK, false); + AddDnsRule("6slow_ok", dns_protocol::kTypeAAAA, MockDnsClientRule::OK, + true); + AddDnsRule("4slow_4ok", dns_protocol::kTypeA, MockDnsClientRule::OK, true); + AddDnsRule("4slow_4ok", dns_protocol::kTypeAAAA, MockDnsClientRule::EMPTY, + false); + AddDnsRule("4slow_4timeout", dns_protocol::kTypeA, + MockDnsClientRule::TIMEOUT, true); + AddDnsRule("4slow_4timeout", dns_protocol::kTypeAAAA, MockDnsClientRule::OK, + false); + AddDnsRule("4slow_6timeout", dns_protocol::kTypeA, + MockDnsClientRule::OK, true); + AddDnsRule("4slow_6timeout", dns_protocol::kTypeAAAA, + MockDnsClientRule::TIMEOUT, false); CreateResolver(); } - void CreateResolver() { + // HostResolverImplTest implementation: + virtual void CreateResolverWithLimitsAndParams( + const PrioritizedDispatcher::Limits& limits, + const HostResolverImpl::ProcTaskParams& params) OVERRIDE { resolver_.reset(new HostResolverImpl(HostCache::CreateDefaultCache(), - DefaultLimits(), - DefaultParams(proc_.get()), + limits, + params, NULL)); // Disable IPv6 support probing. resolver_->SetDefaultAddressFamily(ADDRESS_FAMILY_UNSPECIFIED); - resolver_->SetDnsClient(CreateMockDnsClient(DnsConfig(), dns_rules_)); + dns_client_ = new MockDnsClient(DnsConfig(), dns_rules_); + resolver_->SetDnsClient(scoped_ptr<DnsClient>(dns_client_)); } // Adds a rule to |dns_rules_|. Must be followed by |CreateResolver| to apply. void AddDnsRule(const std::string& prefix, uint16 qtype, - MockDnsClientRule::Result result) { - dns_rules_.push_back(MockDnsClientRule(prefix, qtype, result)); + MockDnsClientRule::Result result, + bool delay) { + dns_rules_.push_back(MockDnsClientRule(prefix, qtype, result, delay)); } void ChangeDnsConfig(const DnsConfig& config) { @@ -1294,6 +1345,8 @@ class HostResolverImplDnsTest : public HostResolverImplTest { } MockDnsClientRuleList dns_rules_; + // Owned by |resolver_|. + MockDnsClient* dns_client_; }; // TODO(szym): Test AbortAllInProgressJobs due to DnsConfig change. @@ -1361,7 +1414,8 @@ TEST_F(HostResolverImplDnsTest, NoFallbackToProcTask) { // Simulate the case when the preference or policy has disabled the DNS client // causing AbortDnsTasks. - resolver_->SetDnsClient(CreateMockDnsClient(DnsConfig(), dns_rules_)); + resolver_->SetDnsClient( + scoped_ptr<DnsClient>(new MockDnsClient(DnsConfig(), dns_rules_))); ChangeDnsConfig(CreateValidDnsConfig()); // First request is resolved by MockDnsClient, others should fail due to @@ -1509,6 +1563,24 @@ TEST_F(HostResolverImplDnsTest, BypassDnsTask) { EXPECT_EQ(OK, requests_[i]->WaitForResult()) << i; } +TEST_F(HostResolverImplDnsTest, SystemOnlyBypassesDnsTask) { + ChangeDnsConfig(CreateValidDnsConfig()); + + proc_->AddRuleForAllFamilies(std::string(), std::string()); + + HostResolver::RequestInfo info_bypass(HostPortPair("ok", 80)); + info_bypass.set_host_resolver_flags(HOST_RESOLVER_SYSTEM_ONLY); + EXPECT_EQ(ERR_IO_PENDING, CreateRequest(info_bypass, MEDIUM)->Resolve()); + + HostResolver::RequestInfo info(HostPortPair("ok", 80)); + EXPECT_EQ(ERR_IO_PENDING, CreateRequest(info, MEDIUM)->Resolve()); + + proc_->SignalMultiple(requests_.size()); + + EXPECT_EQ(ERR_NAME_NOT_RESOLVED, requests_[0]->WaitForResult()); + EXPECT_EQ(OK, requests_[1]->WaitForResult()); +} + TEST_F(HostResolverImplDnsTest, DisableDnsClientOnPersistentFailure) { ChangeDnsConfig(CreateValidDnsConfig()); @@ -1520,7 +1592,7 @@ TEST_F(HostResolverImplDnsTest, DisableDnsClientOnPersistentFailure) { EXPECT_EQ(ERR_IO_PENDING, req->Resolve()); EXPECT_EQ(OK, req->WaitForResult()); - for (unsigned i = 0; i < 20; ++i) { + for (unsigned i = 0; i < maximum_dns_failures(); ++i) { // Use custom names to require separate Jobs. std::string hostname = base::StringPrintf("nx_%u", i); // Ensure fallback to ProcTask succeeds. @@ -1585,7 +1657,8 @@ TEST_F(HostResolverImplDnsTest, DualFamilyLocalhost) { DefaultLimits(), DefaultParams(proc.get()), NULL)); - resolver_->SetDnsClient(CreateMockDnsClient(DnsConfig(), dns_rules_)); + resolver_->SetDnsClient( + scoped_ptr<DnsClient>(new MockDnsClient(DnsConfig(), dns_rules_))); resolver_->SetDefaultAddressFamily(ADDRESS_FAMILY_IPV4); // Get the expected output. @@ -1609,7 +1682,7 @@ TEST_F(HostResolverImplDnsTest, DualFamilyLocalhost) { // Try without DnsClient. ChangeDnsConfig(DnsConfig()); - Request* req = CreateRequest(info); + Request* req = CreateRequest(info, DEFAULT_PRIORITY); // It is resolved via getaddrinfo, so expect asynchronous result. EXPECT_EQ(ERR_IO_PENDING, req->Resolve()); EXPECT_EQ(OK, req->WaitForResult()); @@ -1630,7 +1703,7 @@ TEST_F(HostResolverImplDnsTest, DualFamilyLocalhost) { config.hosts = hosts; ChangeDnsConfig(config); - req = CreateRequest(info); + req = CreateRequest(info, DEFAULT_PRIORITY); // Expect synchronous resolution from DnsHosts. EXPECT_EQ(OK, req->Resolve()); @@ -1638,4 +1711,346 @@ TEST_F(HostResolverImplDnsTest, DualFamilyLocalhost) { EXPECT_EQ(saw_ipv6, req->HasAddress("::1", 80)); } +// Cancel a request with a single DNS transaction active. +TEST_F(HostResolverImplDnsTest, CancelWithOneTransactionActive) { + resolver_->SetDefaultAddressFamily(ADDRESS_FAMILY_IPV4); + ChangeDnsConfig(CreateValidDnsConfig()); + + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("ok", 80)->Resolve()); + EXPECT_EQ(1u, num_running_dispatcher_jobs()); + requests_[0]->Cancel(); + + // Dispatcher state checked in TearDown. +} + +// Cancel a request with a single DNS transaction active and another pending. +TEST_F(HostResolverImplDnsTest, CancelWithOneTransactionActiveOnePending) { + CreateSerialResolver(); + resolver_->SetDefaultAddressFamily(ADDRESS_FAMILY_UNSPECIFIED); + ChangeDnsConfig(CreateValidDnsConfig()); + + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("ok", 80)->Resolve()); + EXPECT_EQ(1u, num_running_dispatcher_jobs()); + requests_[0]->Cancel(); + + // Dispatcher state checked in TearDown. +} + +// Cancel a request with two DNS transactions active. +TEST_F(HostResolverImplDnsTest, CancelWithTwoTransactionsActive) { + resolver_->SetDefaultAddressFamily(ADDRESS_FAMILY_UNSPECIFIED); + ChangeDnsConfig(CreateValidDnsConfig()); + + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("ok", 80)->Resolve()); + EXPECT_EQ(2u, num_running_dispatcher_jobs()); + requests_[0]->Cancel(); + + // Dispatcher state checked in TearDown. +} + +// Delete a resolver with some active requests and some queued requests. +TEST_F(HostResolverImplDnsTest, DeleteWithActiveTransactions) { + // At most 10 Jobs active at once. + CreateResolverWithLimitsAndParams( + PrioritizedDispatcher::Limits(NUM_PRIORITIES, 10u), + DefaultParams(proc_.get())); + + resolver_->SetDefaultAddressFamily(ADDRESS_FAMILY_UNSPECIFIED); + ChangeDnsConfig(CreateValidDnsConfig()); + + // First active job is an IPv4 request. + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("ok", 80, MEDIUM, + ADDRESS_FAMILY_IPV4)->Resolve()); + + // Add 10 more DNS lookups for different hostnames. First 4 should have two + // active jobs, next one has a single active job, and one pending. Others + // should all be queued. + for (int i = 0; i < 10; ++i) { + EXPECT_EQ(ERR_IO_PENDING, CreateRequest( + base::StringPrintf("ok%i", i))->Resolve()); + } + EXPECT_EQ(10u, num_running_dispatcher_jobs()); + + resolver_.reset(); +} + +// Cancel a request with only the IPv6 transaction active. +TEST_F(HostResolverImplDnsTest, CancelWithIPv6TransactionActive) { + resolver_->SetDefaultAddressFamily(ADDRESS_FAMILY_UNSPECIFIED); + ChangeDnsConfig(CreateValidDnsConfig()); + + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("6slow_ok", 80)->Resolve()); + EXPECT_EQ(2u, num_running_dispatcher_jobs()); + + // The IPv4 request should complete, the IPv6 request is still pending. + base::RunLoop().RunUntilIdle(); + EXPECT_EQ(1u, num_running_dispatcher_jobs()); + requests_[0]->Cancel(); + + // Dispatcher state checked in TearDown. +} + +// Cancel a request with only the IPv4 transaction pending. +TEST_F(HostResolverImplDnsTest, CancelWithIPv4TransactionPending) { + set_fallback_to_proctask(false); + resolver_->SetDefaultAddressFamily(ADDRESS_FAMILY_UNSPECIFIED); + ChangeDnsConfig(CreateValidDnsConfig()); + + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("4slow_ok", 80)->Resolve()); + EXPECT_EQ(2u, num_running_dispatcher_jobs()); + + // The IPv6 request should complete, the IPv4 request is still pending. + base::RunLoop().RunUntilIdle(); + EXPECT_EQ(1u, num_running_dispatcher_jobs()); + + requests_[0]->Cancel(); +} + +// Test cases where AAAA completes first. +TEST_F(HostResolverImplDnsTest, AAAACompletesFirst) { + set_fallback_to_proctask(false); + resolver_->SetDefaultAddressFamily(ADDRESS_FAMILY_UNSPECIFIED); + ChangeDnsConfig(CreateValidDnsConfig()); + + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("4slow_ok", 80)->Resolve()); + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("4slow_4ok", 80)->Resolve()); + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("4slow_4timeout", 80)->Resolve()); + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("4slow_6timeout", 80)->Resolve()); + + base::RunLoop().RunUntilIdle(); + EXPECT_FALSE(requests_[0]->completed()); + EXPECT_FALSE(requests_[1]->completed()); + EXPECT_FALSE(requests_[2]->completed()); + // The IPv6 of the third request should have failed and resulted in cancelling + // the IPv4 request. + EXPECT_TRUE(requests_[3]->completed()); + EXPECT_EQ(ERR_DNS_TIMED_OUT, requests_[3]->result()); + EXPECT_EQ(3u, num_running_dispatcher_jobs()); + + dns_client_->CompleteDelayedTransactions(); + EXPECT_TRUE(requests_[0]->completed()); + EXPECT_EQ(OK, requests_[0]->result()); + EXPECT_EQ(2u, requests_[0]->NumberOfAddresses()); + EXPECT_TRUE(requests_[0]->HasAddress("127.0.0.1", 80)); + EXPECT_TRUE(requests_[0]->HasAddress("::1", 80)); + + EXPECT_TRUE(requests_[1]->completed()); + EXPECT_EQ(OK, requests_[1]->result()); + EXPECT_EQ(1u, requests_[1]->NumberOfAddresses()); + EXPECT_TRUE(requests_[1]->HasAddress("127.0.0.1", 80)); + + EXPECT_TRUE(requests_[2]->completed()); + EXPECT_EQ(ERR_DNS_TIMED_OUT, requests_[2]->result()); +} + +// Test the case where only a single transaction slot is available. +TEST_F(HostResolverImplDnsTest, SerialResolver) { + CreateSerialResolver(); + set_fallback_to_proctask(false); + resolver_->SetDefaultAddressFamily(ADDRESS_FAMILY_UNSPECIFIED); + ChangeDnsConfig(CreateValidDnsConfig()); + + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("ok", 80)->Resolve()); + EXPECT_EQ(1u, num_running_dispatcher_jobs()); + + base::RunLoop().RunUntilIdle(); + EXPECT_TRUE(requests_[0]->completed()); + EXPECT_EQ(OK, requests_[0]->result()); + EXPECT_EQ(2u, requests_[0]->NumberOfAddresses()); + EXPECT_TRUE(requests_[0]->HasAddress("127.0.0.1", 80)); + EXPECT_TRUE(requests_[0]->HasAddress("::1", 80)); +} + +// Test the case where the AAAA query is started when another transaction +// completes. +TEST_F(HostResolverImplDnsTest, AAAAStartsAfterOtherJobFinishes) { + CreateResolverWithLimitsAndParams( + PrioritizedDispatcher::Limits(NUM_PRIORITIES, 2), + DefaultParams(proc_.get())); + set_fallback_to_proctask(false); + resolver_->SetDefaultAddressFamily(ADDRESS_FAMILY_UNSPECIFIED); + ChangeDnsConfig(CreateValidDnsConfig()); + + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("ok", 80, MEDIUM, + ADDRESS_FAMILY_IPV4)->Resolve()); + EXPECT_EQ(ERR_IO_PENDING, + CreateRequest("4slow_ok", 80, MEDIUM)->Resolve()); + // An IPv4 request should have been started pending for each job. + EXPECT_EQ(2u, num_running_dispatcher_jobs()); + + // Request 0's IPv4 request should complete, starting Request 1's IPv6 + // request, which should also complete. + base::RunLoop().RunUntilIdle(); + EXPECT_EQ(1u, num_running_dispatcher_jobs()); + EXPECT_TRUE(requests_[0]->completed()); + EXPECT_FALSE(requests_[1]->completed()); + + dns_client_->CompleteDelayedTransactions(); + EXPECT_TRUE(requests_[1]->completed()); + EXPECT_EQ(OK, requests_[1]->result()); + EXPECT_EQ(2u, requests_[1]->NumberOfAddresses()); + EXPECT_TRUE(requests_[1]->HasAddress("127.0.0.1", 80)); + EXPECT_TRUE(requests_[1]->HasAddress("::1", 80)); +} + +// Tests the case that a Job with a single transaction receives an empty address +// list, triggering fallback to ProcTask. +TEST_F(HostResolverImplDnsTest, IPv4EmptyFallback) { + ChangeDnsConfig(CreateValidDnsConfig()); + proc_->AddRuleForAllFamilies("empty_fallback", "192.168.0.1"); + proc_->SignalMultiple(1u); + EXPECT_EQ(ERR_IO_PENDING, + CreateRequest("empty_fallback", 80, MEDIUM, + ADDRESS_FAMILY_IPV4)->Resolve()); + EXPECT_EQ(OK, requests_[0]->WaitForResult()); + EXPECT_TRUE(requests_[0]->HasOneAddress("192.168.0.1", 80)); +} + +// Tests the case that a Job with two transactions receives two empty address +// lists, triggering fallback to ProcTask. +TEST_F(HostResolverImplDnsTest, UnspecEmptyFallback) { + ChangeDnsConfig(CreateValidDnsConfig()); + proc_->AddRuleForAllFamilies("empty_fallback", "192.168.0.1"); + proc_->SignalMultiple(1u); + EXPECT_EQ(ERR_IO_PENDING, + CreateRequest("empty_fallback", 80, MEDIUM, + ADDRESS_FAMILY_UNSPECIFIED)->Resolve()); + EXPECT_EQ(OK, requests_[0]->WaitForResult()); + EXPECT_TRUE(requests_[0]->HasOneAddress("192.168.0.1", 80)); +} + +// Tests getting a new invalid DnsConfig while there are active DnsTasks. +TEST_F(HostResolverImplDnsTest, InvalidDnsConfigWithPendingRequests) { + // At most 3 jobs active at once. This number is important, since we want to + // make sure that aborting the first HostResolverImpl::Job does not trigger + // another DnsTransaction on the second Job when it releases its second + // prioritized dispatcher slot. + CreateResolverWithLimitsAndParams( + PrioritizedDispatcher::Limits(NUM_PRIORITIES, 3u), + DefaultParams(proc_.get())); + + resolver_->SetDefaultAddressFamily(ADDRESS_FAMILY_UNSPECIFIED); + ChangeDnsConfig(CreateValidDnsConfig()); + + proc_->AddRuleForAllFamilies("slow_nx1", "192.168.0.1"); + proc_->AddRuleForAllFamilies("slow_nx2", "192.168.0.2"); + proc_->AddRuleForAllFamilies("ok", "192.168.0.3"); + + // First active job gets two slots. + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("slow_nx1")->Resolve()); + // Next job gets one slot, and waits on another. + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("slow_nx2")->Resolve()); + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("ok")->Resolve()); + + EXPECT_EQ(3u, num_running_dispatcher_jobs()); + + // Clear DNS config. Two in-progress jobs should be aborted, and the next one + // should use a ProcTask. + ChangeDnsConfig(DnsConfig()); + EXPECT_EQ(ERR_NETWORK_CHANGED, requests_[0]->WaitForResult()); + EXPECT_EQ(ERR_NETWORK_CHANGED, requests_[1]->WaitForResult()); + + // Finish up the third job. Should bypass the DnsClient, and get its results + // from MockHostResolverProc. + EXPECT_FALSE(requests_[2]->completed()); + proc_->SignalMultiple(1u); + EXPECT_EQ(OK, requests_[2]->WaitForResult()); + EXPECT_TRUE(requests_[2]->HasOneAddress("192.168.0.3", 80)); +} + +// Tests the case that DnsClient is automatically disabled due to failures +// while there are active DnsTasks. +TEST_F(HostResolverImplDnsTest, + AutomaticallyDisableDnsClientWithPendingRequests) { + // Trying different limits is important for this test: Different limits + // result in different behavior when aborting in-progress DnsTasks. Having + // a DnsTask that has one job active and one in the queue when another job + // occupying two slots has its DnsTask aborted is the case most likely to run + // into problems. + for (size_t limit = 1u; limit < 6u; ++limit) { + CreateResolverWithLimitsAndParams( + PrioritizedDispatcher::Limits(NUM_PRIORITIES, limit), + DefaultParams(proc_.get())); + + resolver_->SetDefaultAddressFamily(ADDRESS_FAMILY_UNSPECIFIED); + ChangeDnsConfig(CreateValidDnsConfig()); + + // Queue up enough failures to disable DnsTasks. These will all fall back + // to ProcTasks, and succeed there. + for (unsigned i = 0u; i < maximum_dns_failures(); ++i) { + std::string host = base::StringPrintf("nx%u", i); + proc_->AddRuleForAllFamilies(host, "192.168.0.1"); + EXPECT_EQ(ERR_IO_PENDING, CreateRequest(host)->Resolve()); + } + + // These requests should all bypass DnsTasks, due to the above failures, + // so should end up using ProcTasks. + proc_->AddRuleForAllFamilies("slow_ok1", "192.168.0.2"); + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("slow_ok1")->Resolve()); + proc_->AddRuleForAllFamilies("slow_ok2", "192.168.0.3"); + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("slow_ok2")->Resolve()); + proc_->AddRuleForAllFamilies("slow_ok3", "192.168.0.4"); + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("slow_ok3")->Resolve()); + proc_->SignalMultiple(maximum_dns_failures() + 3); + + for (size_t i = 0u; i < maximum_dns_failures(); ++i) { + EXPECT_EQ(OK, requests_[i]->WaitForResult()); + EXPECT_TRUE(requests_[i]->HasOneAddress("192.168.0.1", 80)); + } + + EXPECT_EQ(OK, requests_[maximum_dns_failures()]->WaitForResult()); + EXPECT_TRUE(requests_[maximum_dns_failures()]->HasOneAddress( + "192.168.0.2", 80)); + EXPECT_EQ(OK, requests_[maximum_dns_failures() + 1]->WaitForResult()); + EXPECT_TRUE(requests_[maximum_dns_failures() + 1]->HasOneAddress( + "192.168.0.3", 80)); + EXPECT_EQ(OK, requests_[maximum_dns_failures() + 2]->WaitForResult()); + EXPECT_TRUE(requests_[maximum_dns_failures() + 2]->HasOneAddress( + "192.168.0.4", 80)); + requests_.clear(); + } +} + +// Tests a call to SetDnsClient while there are active DnsTasks. +TEST_F(HostResolverImplDnsTest, ManuallyDisableDnsClientWithPendingRequests) { + // At most 3 jobs active at once. This number is important, since we want to + // make sure that aborting the first HostResolverImpl::Job does not trigger + // another DnsTransaction on the second Job when it releases its second + // prioritized dispatcher slot. + CreateResolverWithLimitsAndParams( + PrioritizedDispatcher::Limits(NUM_PRIORITIES, 3u), + DefaultParams(proc_.get())); + + resolver_->SetDefaultAddressFamily(ADDRESS_FAMILY_UNSPECIFIED); + ChangeDnsConfig(CreateValidDnsConfig()); + + proc_->AddRuleForAllFamilies("slow_ok1", "192.168.0.1"); + proc_->AddRuleForAllFamilies("slow_ok2", "192.168.0.2"); + proc_->AddRuleForAllFamilies("ok", "192.168.0.3"); + + // First active job gets two slots. + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("slow_ok1")->Resolve()); + // Next job gets one slot, and waits on another. + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("slow_ok2")->Resolve()); + // Next one is queued. + EXPECT_EQ(ERR_IO_PENDING, CreateRequest("ok")->Resolve()); + + EXPECT_EQ(3u, num_running_dispatcher_jobs()); + + // Clear DnsClient. The two in-progress jobs should fall back to a ProcTask, + // and the next one should be started with a ProcTask. + resolver_->SetDnsClient(scoped_ptr<DnsClient>()); + + // All three in-progress requests should now be running a ProcTask. + EXPECT_EQ(3u, num_running_dispatcher_jobs()); + proc_->SignalMultiple(3u); + + EXPECT_EQ(OK, requests_[0]->WaitForResult()); + EXPECT_TRUE(requests_[0]->HasOneAddress("192.168.0.1", 80)); + EXPECT_EQ(OK, requests_[1]->WaitForResult()); + EXPECT_TRUE(requests_[1]->HasOneAddress("192.168.0.2", 80)); + EXPECT_EQ(OK, requests_[2]->WaitForResult()); + EXPECT_TRUE(requests_[2]->HasOneAddress("192.168.0.3", 80)); +} + } // namespace net diff --git a/chromium/net/dns/mapped_host_resolver.cc b/chromium/net/dns/mapped_host_resolver.cc index 4db7bc97928..7b182077cfb 100644 --- a/chromium/net/dns/mapped_host_resolver.cc +++ b/chromium/net/dns/mapped_host_resolver.cc @@ -19,6 +19,7 @@ MappedHostResolver::~MappedHostResolver() { } int MappedHostResolver::Resolve(const RequestInfo& original_info, + RequestPriority priority, AddressList* addresses, const CompletionCallback& callback, RequestHandle* out_req, @@ -28,7 +29,7 @@ int MappedHostResolver::Resolve(const RequestInfo& original_info, if (rv != OK) return rv; - return impl_->Resolve(info, addresses, callback, out_req, net_log); + return impl_->Resolve(info, priority, addresses, callback, out_req, net_log); } int MappedHostResolver::ResolveFromCache(const RequestInfo& original_info, diff --git a/chromium/net/dns/mapped_host_resolver.h b/chromium/net/dns/mapped_host_resolver.h index 50062a9848e..a121d4ed4c4 100644 --- a/chromium/net/dns/mapped_host_resolver.h +++ b/chromium/net/dns/mapped_host_resolver.h @@ -46,6 +46,7 @@ class NET_EXPORT MappedHostResolver : public HostResolver { // HostResolver methods: virtual int Resolve(const RequestInfo& info, + RequestPriority priority, AddressList* addresses, const CompletionCallback& callback, RequestHandle* out_req, diff --git a/chromium/net/dns/mapped_host_resolver_unittest.cc b/chromium/net/dns/mapped_host_resolver_unittest.cc index d8594663c9d..fbbfde8adff 100644 --- a/chromium/net/dns/mapped_host_resolver_unittest.cc +++ b/chromium/net/dns/mapped_host_resolver_unittest.cc @@ -40,10 +40,13 @@ TEST(MappedHostResolverTest, Inclusion) { // Try resolving "www.google.com:80". There are no mappings yet, so this // hits |resolver_impl| and fails. TestCompletionCallback callback; - rv = resolver->Resolve(HostResolver::RequestInfo( - HostPortPair("www.google.com", 80)), - &address_list, callback.callback(), NULL, - BoundNetLog()); + rv = resolver->Resolve( + HostResolver::RequestInfo(HostPortPair("www.google.com", 80)), + DEFAULT_PRIORITY, + &address_list, + callback.callback(), + NULL, + BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); rv = callback.WaitForResult(); EXPECT_EQ(ERR_NAME_NOT_RESOLVED, rv); @@ -52,10 +55,13 @@ TEST(MappedHostResolverTest, Inclusion) { EXPECT_TRUE(resolver->AddRuleFromString("map *.google.com baz.com")); // Try resolving "www.google.com:80". Should be remapped to "baz.com:80". - rv = resolver->Resolve(HostResolver::RequestInfo( - HostPortPair("www.google.com", 80)), - &address_list, callback.callback(), NULL, - BoundNetLog()); + rv = resolver->Resolve( + HostResolver::RequestInfo(HostPortPair("www.google.com", 80)), + DEFAULT_PRIORITY, + &address_list, + callback.callback(), + NULL, + BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); rv = callback.WaitForResult(); EXPECT_EQ(OK, rv); @@ -64,7 +70,10 @@ TEST(MappedHostResolverTest, Inclusion) { // Try resolving "foo.com:77". This will NOT be remapped, so result // is "foo.com:77". rv = resolver->Resolve(HostResolver::RequestInfo(HostPortPair("foo.com", 77)), - &address_list, callback.callback(), NULL, + DEFAULT_PRIORITY, + &address_list, + callback.callback(), + NULL, BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); rv = callback.WaitForResult(); @@ -75,10 +84,13 @@ TEST(MappedHostResolverTest, Inclusion) { EXPECT_TRUE(resolver->AddRuleFromString("Map *.org proxy:99")); // Try resolving "chromium.org:61". Should be remapped to "proxy:99". - rv = resolver->Resolve(HostResolver::RequestInfo - (HostPortPair("chromium.org", 61)), - &address_list, callback.callback(), NULL, - BoundNetLog()); + rv = resolver->Resolve( + HostResolver::RequestInfo(HostPortPair("chromium.org", 61)), + DEFAULT_PRIORITY, + &address_list, + callback.callback(), + NULL, + BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); rv = callback.WaitForResult(); EXPECT_EQ(OK, rv); @@ -107,20 +119,26 @@ TEST(MappedHostResolverTest, Exclusion) { EXPECT_TRUE(resolver->AddRuleFromString("EXCLUDE *.google.com")); // Try resolving "www.google.com". Should not be remapped due to exclusion). - rv = resolver->Resolve(HostResolver::RequestInfo( - HostPortPair("www.google.com", 80)), - &address_list, callback.callback(), NULL, - BoundNetLog()); + rv = resolver->Resolve( + HostResolver::RequestInfo(HostPortPair("www.google.com", 80)), + DEFAULT_PRIORITY, + &address_list, + callback.callback(), + NULL, + BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); rv = callback.WaitForResult(); EXPECT_EQ(OK, rv); EXPECT_EQ("192.168.1.3:80", FirstAddress(address_list)); // Try resolving "chrome.com:80". Should be remapped to "baz:80". - rv = resolver->Resolve(HostResolver::RequestInfo( - HostPortPair("chrome.com", 80)), - &address_list, callback.callback(), NULL, - BoundNetLog()); + rv = resolver->Resolve( + HostResolver::RequestInfo(HostPortPair("chrome.com", 80)), + DEFAULT_PRIORITY, + &address_list, + callback.callback(), + NULL, + BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); rv = callback.WaitForResult(); EXPECT_EQ(OK, rv); @@ -145,20 +163,26 @@ TEST(MappedHostResolverTest, SetRulesFromString) { resolver->SetRulesFromString("map *.com baz , map *.net bar:60"); // Try resolving "www.google.com". Should be remapped to "baz". - rv = resolver->Resolve(HostResolver::RequestInfo( - HostPortPair("www.google.com", 80)), - &address_list, callback.callback(), NULL, - BoundNetLog()); + rv = resolver->Resolve( + HostResolver::RequestInfo(HostPortPair("www.google.com", 80)), + DEFAULT_PRIORITY, + &address_list, + callback.callback(), + NULL, + BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); rv = callback.WaitForResult(); EXPECT_EQ(OK, rv); EXPECT_EQ("192.168.1.7:80", FirstAddress(address_list)); // Try resolving "chrome.net:80". Should be remapped to "bar:60". - rv = resolver->Resolve(HostResolver::RequestInfo( - HostPortPair("chrome.net", 80)), - &address_list, callback.callback(), NULL, - BoundNetLog()); + rv = resolver->Resolve( + HostResolver::RequestInfo(HostPortPair("chrome.net", 80)), + DEFAULT_PRIORITY, + &address_list, + callback.callback(), + NULL, + BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); rv = callback.WaitForResult(); EXPECT_EQ(OK, rv); @@ -196,18 +220,24 @@ TEST(MappedHostResolverTest, MapToError) { // Try resolving www.google.com --> Should give an error. TestCompletionCallback callback1; - rv = resolver->Resolve(HostResolver::RequestInfo( - HostPortPair("www.google.com", 80)), - &address_list, callback1.callback(), NULL, - BoundNetLog()); + rv = resolver->Resolve( + HostResolver::RequestInfo(HostPortPair("www.google.com", 80)), + DEFAULT_PRIORITY, + &address_list, + callback1.callback(), + NULL, + BoundNetLog()); EXPECT_EQ(ERR_NAME_NOT_RESOLVED, rv); // Try resolving www.foo.com --> Should succeed. TestCompletionCallback callback2; - rv = resolver->Resolve(HostResolver::RequestInfo( - HostPortPair("www.foo.com", 80)), - &address_list, callback2.callback(), NULL, - BoundNetLog()); + rv = resolver->Resolve( + HostResolver::RequestInfo(HostPortPair("www.foo.com", 80)), + DEFAULT_PRIORITY, + &address_list, + callback2.callback(), + NULL, + BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); rv = callback2.WaitForResult(); EXPECT_EQ(OK, rv); diff --git a/chromium/net/dns/mdns_client.cc b/chromium/net/dns/mdns_client.cc index 631b01a706e..d0273c5784a 100644 --- a/chromium/net/dns/mdns_client.cc +++ b/chromium/net/dns/mdns_client.cc @@ -4,14 +4,43 @@ #include "net/dns/mdns_client.h" +#include "net/dns/dns_protocol.h" #include "net/dns/mdns_client_impl.h" namespace net { +namespace { + +const char kMDnsMulticastGroupIPv4[] = "224.0.0.251"; +const char kMDnsMulticastGroupIPv6[] = "FF02::FB"; + +IPEndPoint GetMDnsIPEndPoint(const char* address) { + IPAddressNumber multicast_group_number; + bool success = ParseIPLiteralToNumber(address, + &multicast_group_number); + DCHECK(success); + return IPEndPoint(multicast_group_number, + dns_protocol::kDefaultPortMulticast); +} + +} // namespace + // static scoped_ptr<MDnsClient> MDnsClient::CreateDefault() { return scoped_ptr<MDnsClient>( new MDnsClientImpl(MDnsConnection::SocketFactory::CreateDefault())); } +IPEndPoint GetMDnsIPEndPoint(AddressFamily address_family) { + switch (address_family) { + case ADDRESS_FAMILY_IPV4: + return GetMDnsIPEndPoint(kMDnsMulticastGroupIPv4); + case ADDRESS_FAMILY_IPV6: + return GetMDnsIPEndPoint(kMDnsMulticastGroupIPv6); + default: + NOTREACHED(); + return IPEndPoint(); + } +} + } // namespace net diff --git a/chromium/net/dns/mdns_client.h b/chromium/net/dns/mdns_client.h index f12c6e292cd..71006d69b01 100644 --- a/chromium/net/dns/mdns_client.h +++ b/chromium/net/dns/mdns_client.h @@ -9,6 +9,7 @@ #include <vector> #include "base/callback.h" +#include "net/base/ip_endpoint.h" #include "net/dns/dns_query.h" #include "net/dns/dns_response.h" #include "net/dns/record_parsed.h" @@ -154,5 +155,7 @@ class NET_EXPORT MDnsClient { static scoped_ptr<MDnsClient> CreateDefault(); }; +IPEndPoint NET_EXPORT GetMDnsIPEndPoint(AddressFamily address_family); + } // namespace net #endif // NET_DNS_MDNS_CLIENT_H_ diff --git a/chromium/net/dns/mdns_client_impl.cc b/chromium/net/dns/mdns_client_impl.cc index 8f79edf4cdf..1dfc21b37c6 100644 --- a/chromium/net/dns/mdns_client_impl.cc +++ b/chromium/net/dns/mdns_client_impl.cc @@ -24,8 +24,6 @@ namespace net { namespace { -const char kMDnsMulticastGroupIPv4[] = "224.0.0.251"; -const char kMDnsMulticastGroupIPv6[] = "FF02::FB"; const unsigned MDnsTransactionTimeoutSeconds = 3; } @@ -41,11 +39,6 @@ MDnsConnection::SocketHandler::~SocketHandler() { } int MDnsConnection::SocketHandler::Start() { - int rv = BindSocket(); - if (rv != OK) { - return rv; - } - return DoLoop(0); } @@ -87,7 +80,7 @@ void MDnsConnection::SocketHandler::SendDone(int rv) { // TODO(noamsml): Retry logic. } -int MDnsConnection::SocketHandler::BindSocket() { +int MDnsConnection::SocketHandler::Bind() { IPAddressNumber address_any(multicast_addr_.address().size()); IPEndPoint bind_endpoint(address_any, multicast_addr_.port()); @@ -102,41 +95,51 @@ int MDnsConnection::SocketHandler::BindSocket() { return socket_->JoinGroup(multicast_addr_.address()); } -MDnsConnection::MDnsConnection(MDnsConnection::SocketFactory* socket_factory, - MDnsConnection::Delegate* delegate) - : socket_handler_ipv4_(this, - GetMDnsIPEndPoint(kMDnsMulticastGroupIPv4), - socket_factory), - socket_handler_ipv6_(this, - GetMDnsIPEndPoint(kMDnsMulticastGroupIPv6), - socket_factory), +MDnsConnection::MDnsConnection(MDnsConnection::Delegate* delegate) : delegate_(delegate) { } MDnsConnection::~MDnsConnection() { } -int MDnsConnection::Init() { - int rv; +bool MDnsConnection::Init(MDnsConnection::SocketFactory* socket_factory) { + // TODO(vitalybuka): crbug.com/297690 Make socket_factory return list + // of initialized sockets. + socket_handlers_.push_back( + new SocketHandler(this, GetMDnsIPEndPoint(ADDRESS_FAMILY_IPV4), + socket_factory)); + socket_handlers_.push_back( + new SocketHandler(this, GetMDnsIPEndPoint(ADDRESS_FAMILY_IPV6), + socket_factory)); - rv = socket_handler_ipv4_.Start(); - if (rv != OK) return rv; - rv = socket_handler_ipv6_.Start(); - if (rv != OK) return rv; + for (size_t i = 0; i < socket_handlers_.size();) { + if (socket_handlers_[i]->Bind() != OK) { + socket_handlers_.erase(socket_handlers_.begin() + i); + } else { + ++i; + } + } - return OK; + // All unbound sockets need to be bound before processing untrusted input. + // This is done for security reasons, so that an attacker can't get an unbound + // socket. + for (size_t i = 0; i < socket_handlers_.size();) { + if (socket_handlers_[i]->Start() != OK) { + socket_handlers_.erase(socket_handlers_.begin() + i); + } else { + ++i; + } + } + return !socket_handlers_.empty(); } -int MDnsConnection::Send(IOBuffer* buffer, unsigned size) { - int rv; - - rv = socket_handler_ipv4_.Send(buffer, size); - if (rv < OK && rv != ERR_IO_PENDING) return rv; - - rv = socket_handler_ipv6_.Send(buffer, size); - if (rv < OK && rv != ERR_IO_PENDING) return rv; - - return OK; +bool MDnsConnection::Send(IOBuffer* buffer, unsigned size) { + bool success = false; + for (size_t i = 0; i < socket_handlers_.size(); ++i) { + int rv = socket_handlers_[i]->Send(buffer, size); + success = success || (rv >= OK || rv == ERR_IO_PENDING); + } + return success; } void MDnsConnection::OnError(SocketHandler* loop, @@ -146,15 +149,6 @@ void MDnsConnection::OnError(SocketHandler* loop, delegate_->OnConnectionError(error); } -IPEndPoint MDnsConnection::GetMDnsIPEndPoint(const char* address) { - IPAddressNumber multicast_group_number; - bool success = ParseIPLiteralToNumber(address, - &multicast_group_number); - DCHECK(success); - return IPEndPoint(multicast_group_number, - dns_protocol::kDefaultPortMulticast); -} - void MDnsConnection::OnDatagramReceived( DnsResponse* response, const IPEndPoint& recv_addr, @@ -192,17 +186,16 @@ MDnsConnection::SocketFactory::CreateDefault() { new MDnsConnectionSocketFactoryImpl); } -MDnsClientImpl::Core::Core(MDnsClientImpl* client, - MDnsConnection::SocketFactory* socket_factory) - : client_(client), connection_(new MDnsConnection(socket_factory, this)) { +MDnsClientImpl::Core::Core(MDnsClientImpl* client) + : client_(client), connection_(new MDnsConnection(this)) { } MDnsClientImpl::Core::~Core() { STLDeleteValues(&listeners_); } -bool MDnsClientImpl::Core::Init() { - return connection_->Init() == OK; +bool MDnsClientImpl::Core::Init(MDnsConnection::SocketFactory* socket_factory) { + return connection_->Init(socket_factory); } bool MDnsClientImpl::Core::SendQuery(uint16 rrtype, std::string name) { @@ -213,7 +206,7 @@ bool MDnsClientImpl::Core::SendQuery(uint16 rrtype, std::string name) { DnsQuery query(0, name_dns, rrtype); query.set_flags(0); // Remove the RD flag from the query. It is unneeded. - return connection_->Send(query.io_buffer(), query.io_buffer()->size()) == OK; + return connection_->Send(query.io_buffer(), query.io_buffer()->size()); } void MDnsClientImpl::Core::HandlePacket(DnsResponse* response, @@ -378,7 +371,7 @@ void MDnsClientImpl::Core::RemoveListener(MDnsListenerImpl* listener) { observer_list_iterator->second->RemoveObserver(listener); // Remove the observer list from the map if it is empty - if (observer_list_iterator->second->size() == 0) { + if (!observer_list_iterator->second->might_have_observers()) { // Schedule the actual removal for later in case the listener removal // happens while iterating over the observer list. base::MessageLoop::current()->PostTask( @@ -389,7 +382,7 @@ void MDnsClientImpl::Core::RemoveListener(MDnsListenerImpl* listener) { void MDnsClientImpl::Core::CleanupObserverList(const ListenerKey& key) { ListenerMap::iterator found = listeners_.find(key); - if (found != listeners_.end() && found->second->size() == 0) { + if (found != listeners_.end() && !found->second->might_have_observers()) { delete found->second; listeners_.erase(found); } @@ -442,8 +435,8 @@ MDnsClientImpl::~MDnsClientImpl() { bool MDnsClientImpl::StartListening() { DCHECK(!core_.get()); - core_.reset(new Core(this, socket_factory_.get())); - if (!core_->Init()) { + core_.reset(new Core(this)); + if (!core_->Init(socket_factory_.get())) { core_.reset(); return false; } diff --git a/chromium/net/dns/mdns_client_impl.h b/chromium/net/dns/mdns_client_impl.h index 9fe3f99e7dd..b69677bb747 100644 --- a/chromium/net/dns/mdns_client_impl.h +++ b/chromium/net/dns/mdns_client_impl.h @@ -11,6 +11,7 @@ #include <vector> #include "base/cancelable_callback.h" +#include "base/memory/scoped_vector.h" #include "base/observer_list.h" #include "net/base/io_buffer.h" #include "net/base/ip_endpoint.h" @@ -43,13 +44,13 @@ class NET_EXPORT_PRIVATE MDnsConnection { virtual ~Delegate() {} }; - explicit MDnsConnection(SocketFactory* socket_factory, - MDnsConnection::Delegate* delegate); + explicit MDnsConnection(MDnsConnection::Delegate* delegate); virtual ~MDnsConnection(); - int Init(); - int Send(IOBuffer* buffer, unsigned size); + // Both methods return true if at least one of the socket handlers succeeded. + bool Init(MDnsConnection::SocketFactory* socket_factory); + bool Send(IOBuffer* buffer, unsigned size); private: class SocketHandler { @@ -58,13 +59,13 @@ class NET_EXPORT_PRIVATE MDnsConnection { const IPEndPoint& multicast_addr, SocketFactory* socket_factory); ~SocketHandler(); - int DoLoop(int rv); + int Bind(); int Start(); int Send(IOBuffer* buffer, unsigned size); private: - int BindSocket(); + int DoLoop(int rv); void OnDatagramReceived(int rv); // Callback for when sending a query has finished. @@ -85,10 +86,8 @@ class NET_EXPORT_PRIVATE MDnsConnection { void OnError(SocketHandler* loop, int error); - IPEndPoint GetMDnsIPEndPoint(const char* address); - - SocketHandler socket_handler_ipv4_; - SocketHandler socket_handler_ipv6_; + // Only socket handlers which successfully bound and started are kept. + ScopedVector<SocketHandler> socket_handlers_; Delegate* delegate_; @@ -105,12 +104,11 @@ class NET_EXPORT_PRIVATE MDnsClientImpl : public MDnsClient { // invalidate the core. class Core : public base::SupportsWeakPtr<Core>, MDnsConnection::Delegate { public: - Core(MDnsClientImpl* client, - MDnsConnection::SocketFactory* socket_factory); + explicit Core(MDnsClientImpl* client); virtual ~Core(); // Initialize the core. Returns true on success. - bool Init(); + bool Init(MDnsConnection::SocketFactory* socket_factory); // Send a query with a specific rrtype and name. Returns true on success. bool SendQuery(uint16 rrtype, std::string name); diff --git a/chromium/net/dns/mdns_client_unittest.cc b/chromium/net/dns/mdns_client_unittest.cc index 324b4dfbee0..f524a5401c7 100644 --- a/chromium/net/dns/mdns_client_unittest.cc +++ b/chromium/net/dns/mdns_client_unittest.cc @@ -1025,24 +1025,17 @@ class SimpleMockSocketFactory } virtual scoped_ptr<DatagramServerSocket> CreateSocket() OVERRIDE { - scoped_ptr<MockMDnsDatagramServerSocket> socket( - new StrictMock<MockMDnsDatagramServerSocket>); - sockets_.push(socket.get()); - return socket.PassAs<DatagramServerSocket>(); + MockMDnsDatagramServerSocket* socket = sockets_.back(); + sockets_.weak_erase(sockets_.end() - 1); + return scoped_ptr<DatagramServerSocket>(socket); } - MockMDnsDatagramServerSocket* PopFirstSocket() { - MockMDnsDatagramServerSocket* socket = sockets_.front(); - sockets_.pop(); - return socket; - } - - size_t num_sockets() { - return sockets_.size(); + void PushSocket(MockMDnsDatagramServerSocket* socket) { + sockets_.push_back(socket); } private: - std::queue<MockMDnsDatagramServerSocket*> sockets_; + ScopedVector<MockMDnsDatagramServerSocket> sockets_; }; class MockMDnsConnectionDelegate : public MDnsConnection::Delegate { @@ -1058,16 +1051,16 @@ class MockMDnsConnectionDelegate : public MDnsConnection::Delegate { class MDnsConnectionTest : public ::testing::Test { public: - MDnsConnectionTest() : connection_(&factory_, &delegate_) { + MDnsConnectionTest() : connection_(&delegate_) { } protected: // Follow successful connection initialization. virtual void SetUp() OVERRIDE { - ASSERT_EQ(2u, factory_.num_sockets()); - - socket_ipv4_ = factory_.PopFirstSocket(); - socket_ipv6_ = factory_.PopFirstSocket(); + socket_ipv4_ = new MockMDnsDatagramServerSocket; + socket_ipv6_ = new MockMDnsDatagramServerSocket; + factory_.PushSocket(socket_ipv6_); + factory_.PushSocket(socket_ipv4_); } bool InitConnection() { @@ -1087,7 +1080,7 @@ class MDnsConnectionTest : public ::testing::Test { EXPECT_CALL(*socket_ipv6_, JoinGroupInternal("ff02::fb")) .WillOnce(Return(OK)); - return connection_.Init() == OK; + return connection_.Init(&factory_); } StrictMock<MockMDnsConnectionDelegate> delegate_; diff --git a/chromium/net/dns/mock_host_resolver.cc b/chromium/net/dns/mock_host_resolver.cc index b3d1489c9d7..25918ba3247 100644 --- a/chromium/net/dns/mock_host_resolver.cc +++ b/chromium/net/dns/mock_host_resolver.cc @@ -67,11 +67,13 @@ MockHostResolverBase::~MockHostResolverBase() { } int MockHostResolverBase::Resolve(const RequestInfo& info, + RequestPriority priority, AddressList* addresses, const CompletionCallback& callback, RequestHandle* handle, const BoundNetLog& net_log) { DCHECK(CalledOnValidThread()); + last_request_priority_ = priority; num_resolve_++; size_t id = next_request_id_++; int rv = ResolveFromIPLiteralOrCache(info, addresses); @@ -135,7 +137,8 @@ void MockHostResolverBase::ResolveAllPending() { // start id from 1 to distinguish from NULL RequestHandle MockHostResolverBase::MockHostResolverBase(bool use_caching) - : synchronous_mode_(false), + : last_request_priority_(DEFAULT_PRIORITY), + synchronous_mode_(false), ondemand_mode_(false), next_request_id_(1), num_resolve_(0), @@ -313,12 +316,14 @@ int RuleBasedHostResolverProc::Resolve(const std::string& host, bool matches_address_family = r->address_family == ADDRESS_FAMILY_UNSPECIFIED || r->address_family == address_family; + // Ignore HOST_RESOLVER_SYSTEM_ONLY, since it should have no impact on + // whether a rule matches. + HostResolverFlags flags = host_resolver_flags & ~HOST_RESOLVER_SYSTEM_ONLY; // Flags match if all of the bitflags in host_resolver_flags are enabled // in the rule's host_resolver_flags. However, the rule may have additional // flags specified, in which case the flags should still be considered a // match. - bool matches_flags = (r->host_resolver_flags & host_resolver_flags) == - host_resolver_flags; + bool matches_flags = (r->host_resolver_flags & flags) == flags; if (matches_flags && matches_address_family && MatchPattern(host, r->host_pattern)) { if (r->latency_ms != 0) { @@ -370,6 +375,7 @@ RuleBasedHostResolverProc* CreateCatchAllHostResolverProc() { //----------------------------------------------------------------------------- int HangingHostResolver::Resolve(const RequestInfo& info, + RequestPriority priority, AddressList* addresses, const CompletionCallback& callback, RequestHandle* out_req, diff --git a/chromium/net/dns/mock_host_resolver.h b/chromium/net/dns/mock_host_resolver.h index 282521cc3b9..f8a424011a9 100644 --- a/chromium/net/dns/mock_host_resolver.h +++ b/chromium/net/dns/mock_host_resolver.h @@ -75,6 +75,7 @@ class MockHostResolverBase : public HostResolver, // HostResolver methods: virtual int Resolve(const RequestInfo& info, + RequestPriority priority, AddressList* addresses, const CompletionCallback& callback, RequestHandle* out_req, @@ -104,6 +105,12 @@ class MockHostResolverBase : public HostResolver, return num_resolve_from_cache_; } + // Returns the RequestPriority of the last call to Resolve() (or + // DEFAULT_PRIORITY if Resolve() hasn't been called yet). + RequestPriority last_request_priority() const { + return last_request_priority_; + } + protected: explicit MockHostResolverBase(bool use_caching); @@ -120,6 +127,7 @@ class MockHostResolverBase : public HostResolver, // Resolve request stored in |requests_|. Pass rv to callback. void ResolveNow(size_t id); + RequestPriority last_request_priority_; bool synchronous_mode_; bool ondemand_mode_; scoped_refptr<RuleBasedHostResolverProc> rules_; @@ -246,6 +254,7 @@ RuleBasedHostResolverProc* CreateCatchAllHostResolverProc(); class HangingHostResolver : public HostResolver { public: virtual int Resolve(const RequestInfo& info, + RequestPriority priority, AddressList* addresses, const CompletionCallback& callback, RequestHandle* out_req, diff --git a/chromium/net/dns/single_request_host_resolver.cc b/chromium/net/dns/single_request_host_resolver.cc index 31ef4c56b6e..7974abafeef 100644 --- a/chromium/net/dns/single_request_host_resolver.cc +++ b/chromium/net/dns/single_request_host_resolver.cc @@ -25,9 +25,11 @@ SingleRequestHostResolver::~SingleRequestHostResolver() { Cancel(); } -int SingleRequestHostResolver::Resolve( - const HostResolver::RequestInfo& info, AddressList* addresses, - const CompletionCallback& callback, const BoundNetLog& net_log) { +int SingleRequestHostResolver::Resolve(const HostResolver::RequestInfo& info, + RequestPriority priority, + AddressList* addresses, + const CompletionCallback& callback, + const BoundNetLog& net_log) { DCHECK(addresses); DCHECK_EQ(false, callback.is_null()); DCHECK(cur_request_callback_.is_null()) << "resolver already in use"; @@ -40,7 +42,7 @@ int SingleRequestHostResolver::Resolve( callback.is_null() ? CompletionCallback() : callback_; int rv = resolver_->Resolve( - info, addresses, transient_callback, &request, net_log); + info, priority, addresses, transient_callback, &request, net_log); if (rv == ERR_IO_PENDING) { DCHECK_EQ(false, callback.is_null()); diff --git a/chromium/net/dns/single_request_host_resolver.h b/chromium/net/dns/single_request_host_resolver.h index 52d01328911..3c8ef020edc 100644 --- a/chromium/net/dns/single_request_host_resolver.h +++ b/chromium/net/dns/single_request_host_resolver.h @@ -5,10 +5,18 @@ #ifndef NET_DNS_SINGLE_REQUEST_HOST_RESOLVER_H_ #define NET_DNS_SINGLE_REQUEST_HOST_RESOLVER_H_ +#include "base/basictypes.h" + +#include "net/base/completion_callback.h" +#include "net/base/net_export.h" +#include "net/base/request_priority.h" #include "net/dns/host_resolver.h" namespace net { +class AddressList; +class BoundNetLog; + // This class represents the task of resolving a hostname (or IP address // literal) to an AddressList object. It wraps HostResolver to resolve only a // single hostname at a time and cancels this request when going out of scope. @@ -25,6 +33,7 @@ class NET_EXPORT SingleRequestHostResolver { // Resolves the given hostname (or IP address literal), filling out the // |addresses| object upon success. See HostResolver::Resolve() for details. int Resolve(const HostResolver::RequestInfo& info, + RequestPriority priority, AddressList* addresses, const CompletionCallback& callback, const BoundNetLog& net_log); diff --git a/chromium/net/dns/single_request_host_resolver_unittest.cc b/chromium/net/dns/single_request_host_resolver_unittest.cc index 1b0198f4fde..cc20bf3fe44 100644 --- a/chromium/net/dns/single_request_host_resolver_unittest.cc +++ b/chromium/net/dns/single_request_host_resolver_unittest.cc @@ -31,6 +31,7 @@ class HangingHostResolver : public HostResolver { } virtual int Resolve(const RequestInfo& info, + RequestPriority priority, AddressList* addresses, const CompletionCallback& callback, RequestHandle* out_req, @@ -77,7 +78,7 @@ TEST(SingleRequestHostResolverTest, NormalResolve) { TestCompletionCallback callback; HostResolver::RequestInfo request(HostPortPair("watsup", 90)); int rv = single_request_resolver.Resolve( - request, &addrlist, callback.callback(), BoundNetLog()); + request, DEFAULT_PRIORITY, &addrlist, callback.callback(), BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_EQ(OK, callback.WaitForResult()); @@ -97,8 +98,11 @@ TEST(SingleRequestHostResolverTest, Cancel) { AddressList addrlist; TestCompletionCallback callback; HostResolver::RequestInfo request(HostPortPair("watsup", 90)); - int rv = single_request_resolver.Resolve( - request, &addrlist, callback.callback(), BoundNetLog()); + int rv = single_request_resolver.Resolve(request, + DEFAULT_PRIORITY, + &addrlist, + callback.callback(), + BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_TRUE(resolver.has_outstanding_request()); } diff --git a/chromium/net/ftp/ftp_directory_listing_parser_unittest.cc b/chromium/net/ftp/ftp_directory_listing_parser_unittest.cc index 4eaf1dd97e4..73d19bc6cac 100644 --- a/chromium/net/ftp/ftp_directory_listing_parser_unittest.cc +++ b/chromium/net/ftp/ftp_directory_listing_parser_unittest.cc @@ -42,8 +42,8 @@ TEST_P(FtpDirectoryListingParserTest, Parse) { SCOPED_TRACE(base::StringPrintf("Test case: %s", GetParam())); std::string test_listing; - EXPECT_TRUE(file_util::ReadFileToString(test_dir.AppendASCII(GetParam()), - &test_listing)); + EXPECT_TRUE(base::ReadFileToString(test_dir.AppendASCII(GetParam()), + &test_listing)); std::vector<FtpDirectoryListingEntry> entries; EXPECT_EQ(OK, ParseFtpDirectoryListing(test_listing, @@ -51,7 +51,7 @@ TEST_P(FtpDirectoryListingParserTest, Parse) { &entries)); std::string expected_listing; - ASSERT_TRUE(file_util::ReadFileToString( + ASSERT_TRUE(base::ReadFileToString( test_dir.AppendASCII(std::string(GetParam()) + ".expected"), &expected_listing)); diff --git a/chromium/net/ftp/ftp_network_transaction.cc b/chromium/net/ftp/ftp_network_transaction.cc index ccd6e2ef3b1..9d0cbe2c596 100644 --- a/chromium/net/ftp/ftp_network_transaction.cc +++ b/chromium/net/ftp/ftp_network_transaction.cc @@ -650,7 +650,9 @@ int FtpNetworkTransaction::DoCtrlResolveHost() { HostResolver::RequestInfo info(HostPortPair::FromURL(request_->url)); // No known referrer. return resolver_.Resolve( - info, &addresses_, + info, + DEFAULT_PRIORITY, + &addresses_, base::Bind(&FtpNetworkTransaction::OnIOComplete, base::Unretained(this)), net_log_); } @@ -663,8 +665,8 @@ int FtpNetworkTransaction::DoCtrlResolveHostComplete(int result) { int FtpNetworkTransaction::DoCtrlConnect() { next_state_ = STATE_CTRL_CONNECT_COMPLETE; - ctrl_socket_.reset(socket_factory_->CreateTransportClientSocket( - addresses_, net_log_.net_log(), net_log_.source())); + ctrl_socket_ = socket_factory_->CreateTransportClientSocket( + addresses_, net_log_.net_log(), net_log_.source()); net_log_.AddEvent( NetLog::TYPE_FTP_CONTROL_CONNECTION, ctrl_socket_->NetLog().source().ToEventParametersCallback()); @@ -1249,8 +1251,8 @@ int FtpNetworkTransaction::DoDataConnect() { return Stop(rv); data_address = AddressList::CreateFromIPAddress( ip_endpoint.address(), data_connection_port_); - data_socket_.reset(socket_factory_->CreateTransportClientSocket( - data_address, net_log_.net_log(), net_log_.source())); + data_socket_ = socket_factory_->CreateTransportClientSocket( + data_address, net_log_.net_log(), net_log_.source()); net_log_.AddEvent( NetLog::TYPE_FTP_DATA_CONNECTION, data_socket_->NetLog().source().ToEventParametersCallback()); diff --git a/chromium/net/http/http_auth_cache.cc b/chromium/net/http/http_auth_cache.cc index 79ea7fd625b..1c8c03fbb25 100644 --- a/chromium/net/http/http_auth_cache.cc +++ b/chromium/net/http/http_auth_cache.cc @@ -43,8 +43,7 @@ bool IsEnclosingPath(const std::string& container, const std::string& path) { void CheckOriginIsValid(const GURL& origin) { DCHECK(origin.is_valid()); // Note that the scheme may be FTP when we're using a HTTP proxy. - DCHECK(origin.SchemeIs("http") || origin.SchemeIs("https") || - origin.SchemeIs("ftp")); + DCHECK(origin.SchemeIsHTTPOrHTTPS() || origin.SchemeIs("ftp")); DCHECK(origin.GetOrigin() == origin); } diff --git a/chromium/net/http/http_auth_controller.cc b/chromium/net/http/http_auth_controller.cc index e9d6171ab5a..9cc57de35aa 100644 --- a/chromium/net/http/http_auth_controller.cc +++ b/chromium/net/http/http_auth_controller.cc @@ -570,4 +570,9 @@ void HttpAuthController::DisableAuthScheme(HttpAuth::Scheme scheme) { disabled_schemes_.insert(scheme); } +void HttpAuthController::DisableEmbeddedIdentity() { + DCHECK(CalledOnValidThread()); + embedded_identity_used_ = true; +} + } // namespace net diff --git a/chromium/net/http/http_auth_controller.h b/chromium/net/http/http_auth_controller.h index 5d5b469a265..fdc352554d2 100644 --- a/chromium/net/http/http_auth_controller.h +++ b/chromium/net/http/http_auth_controller.h @@ -71,6 +71,7 @@ class NET_EXPORT_PRIVATE HttpAuthController virtual bool IsAuthSchemeDisabled(HttpAuth::Scheme scheme) const; virtual void DisableAuthScheme(HttpAuth::Scheme scheme); + virtual void DisableEmbeddedIdentity(); private: // Actions for InvalidateCurrentHandler() diff --git a/chromium/net/http/http_auth_handler_negotiate.cc b/chromium/net/http/http_auth_handler_negotiate.cc index 21f5d65f1f2..13b106925b7 100644 --- a/chromium/net/http/http_auth_handler_negotiate.cc +++ b/chromium/net/http/http_auth_handler_negotiate.cc @@ -291,7 +291,9 @@ int HttpAuthHandlerNegotiate::DoResolveCanonicalName() { info.set_host_resolver_flags(HOST_RESOLVER_CANONNAME); single_resolve_.reset(new SingleRequestHostResolver(resolver_)); return single_resolve_->Resolve( - info, &address_list_, + info, + DEFAULT_PRIORITY, + &address_list_, base::Bind(&HttpAuthHandlerNegotiate::OnIOComplete, base::Unretained(this)), net_log_); diff --git a/chromium/net/http/http_basic_stream.cc b/chromium/net/http/http_basic_stream.cc index d70ac02be74..c30e17d6203 100644 --- a/chromium/net/http/http_basic_stream.cc +++ b/chromium/net/http/http_basic_stream.cc @@ -127,4 +127,8 @@ void HttpBasicStream::Drain(HttpNetworkSession* session) { // |drainer| will delete itself. } +void HttpBasicStream::SetPriority(RequestPriority priority) { + // TODO(akalin): Plumb this through to |connection_|. +} + } // namespace net diff --git a/chromium/net/http/http_basic_stream.h b/chromium/net/http/http_basic_stream.h index 2d4bb65ad32..2057837e9a9 100644 --- a/chromium/net/http/http_basic_stream.h +++ b/chromium/net/http/http_basic_stream.h @@ -81,6 +81,8 @@ class HttpBasicStream : public HttpStream { virtual void Drain(HttpNetworkSession* session) OVERRIDE; + virtual void SetPriority(RequestPriority priority) OVERRIDE; + private: scoped_refptr<GrowableIOBuffer> read_buf_; diff --git a/chromium/net/http/http_cache.cc b/chromium/net/http/http_cache.cc index 4cdcbb6cd88..49591da8ff4 100644 --- a/chromium/net/http/http_cache.cc +++ b/chromium/net/http/http_cache.cc @@ -392,7 +392,7 @@ void HttpCache::CloseAllConnections() { HttpNetworkSession* session = network->GetSession(); if (session) session->CloseAllConnections(); - } +} void HttpCache::CloseIdleConnections() { net::HttpNetworkLayer* network = @@ -595,6 +595,9 @@ int HttpCache::AsyncDoomEntry(const std::string& key, Transaction* trans) { } void HttpCache::DoomMainEntryForUrl(const GURL& url) { + if (!disk_cache_) + return; + HttpRequestInfo temp_info; temp_info.url = url; temp_info.method = "GET"; diff --git a/chromium/net/http/http_cache_transaction.cc b/chromium/net/http/http_cache_transaction.cc index d1307012208..83efcb135cd 100644 --- a/chromium/net/http/http_cache_transaction.cc +++ b/chromium/net/http/http_cache_transaction.cc @@ -196,14 +196,10 @@ HttpCache::Transaction::Transaction( io_callback_(base::Bind(&Transaction::OnIOComplete, weak_factory_.GetWeakPtr())), transaction_pattern_(PATTERN_UNDEFINED), - defer_cache_sensitivity_delay_(false), transaction_delegate_(transaction_delegate) { COMPILE_ASSERT(HttpCache::Transaction::kNumValidationHeaders == arraysize(kValidationHeaders), Invalid_number_of_validation_headers); - base::StringToInt( - base::FieldTrialList::FindFullName("CacheSensitivityAnalysis"), - &sensitivity_analysis_percent_increase_); } HttpCache::Transaction::~Transaction() { @@ -212,14 +208,12 @@ HttpCache::Transaction::~Transaction() { callback_.Reset(); transaction_delegate_ = NULL; - cache_io_start_ = base::TimeTicks(); - deferred_cache_sensitivity_delay_ = base::TimeDelta(); - if (cache_.get()) { + if (cache_) { if (entry_) { - bool cancel_request = reading_; + bool cancel_request = reading_ && response_.headers; if (cancel_request) { - if (partial_.get()) { + if (partial_) { entry_->disk_entry->CancelSparseIO(); } else { cancel_request &= (response_.headers->response_code() == 200); @@ -231,14 +225,6 @@ HttpCache::Transaction::~Transaction() { cache_->RemovePendingTransaction(this); } } - - // Cancel any outstanding callbacks before we drop our reference to the - // HttpCache. This probably isn't strictly necessary, but might as well. - weak_factory_.InvalidateWeakPtrs(); - - // We could still have a cache read or write in progress, so we just null the - // cache_ pointer to signal that we are dead. See DoCacheReadCompleted. - cache_.reset(); } int HttpCache::Transaction::WriteMetadata(IOBuffer* buf, int buf_len, @@ -463,10 +449,16 @@ bool HttpCache::Transaction::GetFullRequestHeaders( void HttpCache::Transaction::DoneReading() { if (cache_.get() && entry_) { - DCHECK(reading_); DCHECK_NE(mode_, UPDATE); - if (mode_ & WRITE) + if (mode_ & WRITE) { DoneWritingToEntry(true); + } else if (mode_ & READ) { + // It is necessary to check mode_ & READ because it is possible + // for mode_ to be NONE and entry_ non-NULL with a write entry + // if StopCaching was called. + cache_->DoneReadingFromEntry(entry_, this); + entry_ = NULL; + } } } @@ -676,9 +668,6 @@ int HttpCache::Transaction::DoLoop(int result) { case STATE_ADD_TO_ENTRY_COMPLETE: rv = DoAddToEntryComplete(rv); break; - case STATE_ADD_TO_ENTRY_COMPLETE_AFTER_DELAY: - rv = DoAddToEntryCompleteAfterDelay(rv); - break; case STATE_START_PARTIAL_CACHE_VALIDATION: DCHECK_EQ(OK, rv); rv = DoStartPartialCacheValidation(); @@ -969,7 +958,7 @@ int HttpCache::Transaction::DoSuccessfulSendRequest() { mode_ = NONE; } - if (mode_ != NONE && request_->method == "POST" && + if (request_->method == "POST" && NonErrorResponse(new_response->headers->response_code())) { cache_->DoomMainEntryForUrl(request_->url); } @@ -1034,8 +1023,7 @@ int HttpCache::Transaction::DoOpenEntry() { net_log_.BeginEvent(NetLog::TYPE_HTTP_CACHE_OPEN_ENTRY); first_cache_access_since_ = TimeTicks::Now(); ReportCacheActionStart(); - defer_cache_sensitivity_delay_ = true; - return ResetCacheIOStart(cache_->OpenEntry(cache_key_, &new_entry_, this)); + return cache_->OpenEntry(cache_key_, &new_entry_, this); } int HttpCache::Transaction::DoOpenEntryComplete(int result) { @@ -1087,8 +1075,7 @@ int HttpCache::Transaction::DoCreateEntry() { cache_pending_ = true; net_log_.BeginEvent(NetLog::TYPE_HTTP_CACHE_CREATE_ENTRY); ReportCacheActionStart(); - defer_cache_sensitivity_delay_ = true; - return ResetCacheIOStart(cache_->CreateEntry(cache_key_, &new_entry_, this)); + return cache_->CreateEntry(cache_key_, &new_entry_, this); } int HttpCache::Transaction::DoCreateEntryComplete(int result) { @@ -1130,7 +1117,7 @@ int HttpCache::Transaction::DoDoomEntry() { first_cache_access_since_ = TimeTicks::Now(); net_log_.BeginEvent(NetLog::TYPE_HTTP_CACHE_DOOM_ENTRY); ReportCacheActionStart(); - return ResetCacheIOStart(cache_->DoomEntry(cache_key_, this)); + return cache_->DoomEntry(cache_key_, this); } int HttpCache::Transaction::DoDoomEntryComplete(int result) { @@ -1154,8 +1141,6 @@ int HttpCache::Transaction::DoAddToEntry() { } int HttpCache::Transaction::DoAddToEntryComplete(int result) { - DCHECK(defer_cache_sensitivity_delay_); - defer_cache_sensitivity_delay_ = false; net_log_.EndEventWithNetErrorCode(NetLog::TYPE_HTTP_CACHE_ADD_TO_ENTRY, result); const TimeDelta entry_lock_wait = @@ -1172,18 +1157,6 @@ int HttpCache::Transaction::DoAddToEntryComplete(int result) { // If there is a failure, the cache should have taken care of new_entry_. new_entry_ = NULL; - next_state_ = STATE_ADD_TO_ENTRY_COMPLETE_AFTER_DELAY; - - if (deferred_cache_sensitivity_delay_ == base::TimeDelta()) - return result; - - base::TimeDelta delay = deferred_cache_sensitivity_delay_; - deferred_cache_sensitivity_delay_ = base::TimeDelta(); - ScheduleDelayedLoop(delay, result); - return ERR_IO_PENDING; -} - -int HttpCache::Transaction::DoAddToEntryCompleteAfterDelay(int result) { if (result == ERR_CACHE_RACE) { next_state_ = STATE_INIT_ENTRY; return OK; @@ -1212,8 +1185,7 @@ int HttpCache::Transaction::DoStartPartialCacheValidation() { return OK; next_state_ = STATE_COMPLETE_PARTIAL_CACHE_VALIDATION; - return ResetCacheIOStart( - partial_->ShouldValidateCache(entry_->disk_entry, io_callback_)); + return partial_->ShouldValidateCache(entry_->disk_entry, io_callback_); } int HttpCache::Transaction::DoCompletePartialCacheValidation(int result) { @@ -1338,8 +1310,7 @@ int HttpCache::Transaction::DoTruncateCachedData() { net_log_.BeginEvent(NetLog::TYPE_HTTP_CACHE_WRITE_DATA); ReportCacheActionStart(); // Truncate the stream. - return ResetCacheIOStart( - WriteToEntry(kResponseContentIndex, 0, NULL, 0, io_callback_)); + return WriteToEntry(kResponseContentIndex, 0, NULL, 0, io_callback_); } int HttpCache::Transaction::DoTruncateCachedDataComplete(int result) { @@ -1363,8 +1334,7 @@ int HttpCache::Transaction::DoTruncateCachedMetadata() { if (net_log_.IsLoggingAllEvents()) net_log_.BeginEvent(NetLog::TYPE_HTTP_CACHE_WRITE_INFO); ReportCacheActionStart(); - return ResetCacheIOStart( - WriteToEntry(kMetadataIndex, 0, NULL, 0, io_callback_)); + return WriteToEntry(kMetadataIndex, 0, NULL, 0, io_callback_); } int HttpCache::Transaction::DoTruncateCachedMetadataComplete(int result) { @@ -1376,10 +1346,6 @@ int HttpCache::Transaction::DoTruncateCachedMetadataComplete(int result) { } } - // If this response is a redirect, then we can stop writing now. (We don't - // need to cache the response body of a redirect.) - if (response_.headers->IsRedirect(NULL)) - DoneWritingToEntry(true); next_state_ = STATE_PARTIAL_HEADERS_RECEIVED; return OK; } @@ -1416,8 +1382,8 @@ int HttpCache::Transaction::DoCacheReadResponse() { net_log_.BeginEvent(NetLog::TYPE_HTTP_CACHE_READ_INFO); ReportCacheActionStart(); - return ResetCacheIOStart(entry_->disk_entry->ReadData( - kResponseInfoIndex, 0, read_buf_.get(), io_buf_len_, io_callback_)); + return entry_->disk_entry->ReadData(kResponseInfoIndex, 0, read_buf_.get(), + io_buf_len_, io_callback_); } int HttpCache::Transaction::DoCacheReadResponseComplete(int result) { @@ -1513,12 +1479,10 @@ int HttpCache::Transaction::DoCacheReadMetadata() { net_log_.BeginEvent(NetLog::TYPE_HTTP_CACHE_READ_INFO); ReportCacheActionStart(); - return ResetCacheIOStart( - entry_->disk_entry->ReadData(kMetadataIndex, - 0, - response_.metadata.get(), - response_.metadata->size(), - io_callback_)); + return entry_->disk_entry->ReadData(kMetadataIndex, 0, + response_.metadata.get(), + response_.metadata->size(), + io_callback_); } int HttpCache::Transaction::DoCacheReadMetadataComplete(int result) { @@ -1531,10 +1495,7 @@ int HttpCache::Transaction::DoCacheReadMetadataComplete(int result) { int HttpCache::Transaction::DoCacheQueryData() { next_state_ = STATE_CACHE_QUERY_DATA_COMPLETE; - - // Balanced in DoCacheQueryDataComplete. - return ResetCacheIOStart( - entry_->disk_entry->ReadyForSparseIO(io_callback_)); + return entry_->disk_entry->ReadyForSparseIO(io_callback_); } int HttpCache::Transaction::DoCacheQueryDataComplete(int result) { @@ -1567,15 +1528,13 @@ int HttpCache::Transaction::DoCacheReadData() { net_log_.BeginEvent(NetLog::TYPE_HTTP_CACHE_READ_DATA); ReportCacheActionStart(); if (partial_.get()) { - return ResetCacheIOStart(partial_->CacheRead( - entry_->disk_entry, read_buf_.get(), io_buf_len_, io_callback_)); + return partial_->CacheRead(entry_->disk_entry, read_buf_.get(), io_buf_len_, + io_callback_); } - return ResetCacheIOStart(entry_->disk_entry->ReadData(kResponseContentIndex, - read_offset_, - read_buf_.get(), - io_buf_len_, - io_callback_)); + return entry_->disk_entry->ReadData(kResponseContentIndex, read_offset_, + read_buf_.get(), io_buf_len_, + io_callback_); } int HttpCache::Transaction::DoCacheReadDataComplete(int result) { @@ -1616,8 +1575,7 @@ int HttpCache::Transaction::DoCacheWriteData(int num_bytes) { ReportCacheActionStart(); } - return ResetCacheIOStart( - AppendResponseDataToEntry(read_buf_.get(), num_bytes, io_callback_)); + return AppendResponseDataToEntry(read_buf_.get(), num_bytes, io_callback_); } int HttpCache::Transaction::DoCacheWriteDataComplete(int result) { @@ -2302,8 +2260,8 @@ int HttpCache::Transaction::WriteResponseInfoToEntry(bool truncated) { data->Done(); io_buf_len_ = data->pickle()->size(); - return ResetCacheIOStart(entry_->disk_entry->WriteData( - kResponseInfoIndex, 0, data.get(), io_buf_len_, io_callback_, true)); + return entry_->disk_entry->WriteData(kResponseInfoIndex, 0, data.get(), + io_buf_len_, io_callback_, true); } int HttpCache::Transaction::AppendResponseDataToEntry( @@ -2400,6 +2358,15 @@ int HttpCache::Transaction::DoRestartPartialRequest() { return OK; } +void HttpCache::Transaction::ResetNetworkTransaction() { + DCHECK(!old_network_trans_load_timing_); + DCHECK(network_trans_); + LoadTimingInfo load_timing; + if (network_trans_->GetLoadTimingInfo(&load_timing)) + old_network_trans_load_timing_.reset(new LoadTimingInfo(load_timing)); + network_trans_.reset(); +} + // Histogram data from the end of 2010 show the following distribution of // response headers: // @@ -2431,73 +2398,6 @@ bool HttpCache::Transaction::CanResume(bool has_data) { } void HttpCache::Transaction::OnIOComplete(int result) { - if (!cache_io_start_.is_null()) { - base::TimeDelta cache_time = base::TimeTicks::Now() - cache_io_start_; - cache_io_start_ = base::TimeTicks(); - if (sensitivity_analysis_percent_increase_ > 0) { - cache_time *= sensitivity_analysis_percent_increase_; - cache_time /= 100; - if (!defer_cache_sensitivity_delay_) { - ScheduleDelayedLoop(cache_time, result); - return; - } else { - deferred_cache_sensitivity_delay_ += cache_time; - } - } - } - DCHECK(cache_io_start_.is_null()); - DoLoop(result); -} - -void HttpCache::Transaction::ScheduleDelayedLoop(base::TimeDelta delay, - int result) { - base::MessageLoop::current()->PostDelayedTask( - FROM_HERE, - base::Bind(&HttpCache::Transaction::RunDelayedLoop, - weak_factory_.GetWeakPtr(), - base::TimeTicks::Now(), - delay, - result), - delay); -} - -void HttpCache::Transaction::RunDelayedLoop(base::TimeTicks delay_start_time, - base::TimeDelta intended_delay, - int result) { - base::TimeDelta actual_delay = base::TimeTicks::Now() - delay_start_time; - int64 ratio; - int64 inverse_ratio; - if (intended_delay.InMicroseconds() > 0) { - ratio = - 100 * actual_delay.InMicroseconds() / intended_delay.InMicroseconds(); - } else { - ratio = 0; - } - if (actual_delay.InMicroseconds() > 0) { - inverse_ratio = - 100 * intended_delay.InMicroseconds() / actual_delay.InMicroseconds(); - } else { - inverse_ratio = 0; - } - bool ratio_sample = base::RandInt(0, 99) < ratio; - bool inverse_ratio_sample = base::RandInt(0, 99) < inverse_ratio; - int intended_delay_ms = intended_delay.InMilliseconds(); - UMA_HISTOGRAM_COUNTS_10000( - "HttpCache.CacheSensitivityAnalysis_IntendedDelayMs", - intended_delay_ms); - if (ratio_sample) { - UMA_HISTOGRAM_COUNTS_10000( - "HttpCache.CacheSensitivityAnalysis_RatioByIntendedDelayMs", - intended_delay_ms); - } - if (inverse_ratio_sample) { - UMA_HISTOGRAM_COUNTS_10000( - "HttpCache.CacheSensitivityAnalysis_InverseRatioByIntendedDelayMs", - intended_delay_ms); - } - - DCHECK(cache_io_start_.is_null()); - DCHECK(deferred_cache_sensitivity_delay_ == base::TimeDelta()); DoLoop(result); } @@ -2607,20 +2507,4 @@ void HttpCache::Transaction::RecordHistograms() { } } -int HttpCache::Transaction::ResetCacheIOStart(int return_value) { - DCHECK(cache_io_start_.is_null()); - if (return_value == ERR_IO_PENDING) - cache_io_start_ = base::TimeTicks::Now(); - return return_value; -} - -void HttpCache::Transaction::ResetNetworkTransaction() { - DCHECK(!old_network_trans_load_timing_); - DCHECK(network_trans_); - LoadTimingInfo load_timing; - if (network_trans_->GetLoadTimingInfo(&load_timing)) - old_network_trans_load_timing_.reset(new LoadTimingInfo(load_timing)); - network_trans_.reset(); -} - } // namespace net diff --git a/chromium/net/http/http_cache_transaction.h b/chromium/net/http/http_cache_transaction.h index 0d70a256b87..b1f32bd820d 100644 --- a/chromium/net/http/http_cache_transaction.h +++ b/chromium/net/http/http_cache_transaction.h @@ -160,7 +160,6 @@ class HttpCache::Transaction : public HttpTransaction { STATE_DOOM_ENTRY_COMPLETE, STATE_ADD_TO_ENTRY, STATE_ADD_TO_ENTRY_COMPLETE, - STATE_ADD_TO_ENTRY_COMPLETE_AFTER_DELAY, STATE_START_PARTIAL_CACHE_VALIDATION, STATE_COMPLETE_PARTIAL_CACHE_VALIDATION, STATE_UPDATE_CACHED_RESPONSE, @@ -232,7 +231,6 @@ class HttpCache::Transaction : public HttpTransaction { int DoDoomEntryComplete(int result); int DoAddToEntry(); int DoAddToEntryComplete(int result); - int DoAddToEntryCompleteAfterDelay(int result); int DoStartPartialCacheValidation(); int DoCompletePartialCacheValidation(int result); int DoUpdateCachedResponse(); @@ -364,6 +362,10 @@ class HttpCache::Transaction : public HttpTransaction { // between the byte range request and the cached entry. int DoRestartPartialRequest(); + // Resets |network_trans_|, which must be non-NULL. Also updates + // |old_network_trans_load_timing_|, which must be NULL when this is called. + void ResetNetworkTransaction(); + // Returns true if we should bother attempting to resume this request if it // is aborted while in progress. If |has_data| is true, the size of the stored // data is considered for the result. @@ -379,19 +381,6 @@ class HttpCache::Transaction : public HttpTransaction { void UpdateTransactionPattern(TransactionPattern new_transaction_pattern); void RecordHistograms(); - // Resets cache_io_start_ to the current time, if |return_value| is - // ERR_IO_PENDING. - // Returns |return_value|. - int ResetCacheIOStart(int return_value); - - void ScheduleDelayedLoop(base::TimeDelta delay, int result); - void RunDelayedLoop(base::TimeTicks delay_start_time, - base::TimeDelta intended_delay, int result); - - // Resets |network_trans_|, which must be non-NULL. Also updates - // |old_network_trans_load_timing_|, which must be NULL when this is called. - void ResetNetworkTransaction(); - State next_state_; const HttpRequestInfo* request_; RequestPriority priority_; @@ -438,21 +427,6 @@ class HttpCache::Transaction : public HttpTransaction { base::TimeTicks first_cache_access_since_; base::TimeTicks send_request_since_; - // For sensitivity analysis (field trials emulating longer cache IO times), - // the time at which a cache IO action has started, or base::TimeTicks() - // if no cache IO action is currently in progress. - base::TimeTicks cache_io_start_; - - // For OpenEntry and CreateEntry, if sensitivity analysis would mandate - // a delay on return, we must defer that delay until AddToEntry has been - // called, to avoid a race condition on the address returned. - base::TimeDelta deferred_cache_sensitivity_delay_; - bool defer_cache_sensitivity_delay_; - - // For sensitivity analysis, the simulated increase in cache service times, - // in percent. - int sensitivity_analysis_percent_increase_; - HttpTransactionDelegate* transaction_delegate_; // Load timing information for the last network request, if any. Set in the diff --git a/chromium/net/http/http_cache_unittest.cc b/chromium/net/http/http_cache_unittest.cc index c2bfde7bce2..e47852c01ac 100644 --- a/chromium/net/http/http_cache_unittest.cc +++ b/chromium/net/http/http_cache_unittest.cc @@ -2895,6 +2895,67 @@ TEST(HttpCache, SimplePOST_Invalidate_205) { RemoveMockTransaction(&transaction); } +// Tests that a successful POST invalidates a previously cached GET, even when +// there is no upload identifier. +TEST(HttpCache, SimplePOST_NoUploadId_Invalidate_205) { + MockHttpCache cache; + + MockTransaction transaction(kSimpleGET_Transaction); + AddMockTransaction(&transaction); + MockHttpRequest req1(transaction); + + // Attempt to populate the cache. + RunTransactionTestWithRequest(cache.http_cache(), transaction, req1, NULL); + + EXPECT_EQ(1, cache.network_layer()->transaction_count()); + EXPECT_EQ(0, cache.disk_cache()->open_count()); + EXPECT_EQ(1, cache.disk_cache()->create_count()); + + ScopedVector<net::UploadElementReader> element_readers; + element_readers.push_back(new net::UploadBytesElementReader("hello", 5)); + net::UploadDataStream upload_data_stream(&element_readers, 0); + + transaction.method = "POST"; + transaction.status = "HTTP/1.1 205 No Content"; + MockHttpRequest req2(transaction); + req2.upload_data_stream = &upload_data_stream; + + RunTransactionTestWithRequest(cache.http_cache(), transaction, req2, NULL); + + EXPECT_EQ(2, cache.network_layer()->transaction_count()); + EXPECT_EQ(0, cache.disk_cache()->open_count()); + EXPECT_EQ(1, cache.disk_cache()->create_count()); + + RunTransactionTestWithRequest(cache.http_cache(), transaction, req1, NULL); + + EXPECT_EQ(3, cache.network_layer()->transaction_count()); + EXPECT_EQ(0, cache.disk_cache()->open_count()); + EXPECT_EQ(2, cache.disk_cache()->create_count()); + RemoveMockTransaction(&transaction); +} + +// Tests that processing a POST before creating the backend doesn't crash. +TEST(HttpCache, SimplePOST_NoUploadId_NoBackend) { + // This will initialize a cache object with NULL backend. + MockBlockingBackendFactory* factory = new MockBlockingBackendFactory(); + factory->set_fail(true); + factory->FinishCreation(); + MockHttpCache cache(factory); + + ScopedVector<net::UploadElementReader> element_readers; + element_readers.push_back(new net::UploadBytesElementReader("hello", 5)); + net::UploadDataStream upload_data_stream(&element_readers, 0); + + MockTransaction transaction(kSimplePOST_Transaction); + AddMockTransaction(&transaction); + MockHttpRequest req(transaction); + req.upload_data_stream = &upload_data_stream; + + RunTransactionTestWithRequest(cache.http_cache(), transaction, req, NULL); + + RemoveMockTransaction(&transaction); +} + // Tests that we don't invalidate entries as a result of a failed POST. TEST(HttpCache, SimplePOST_DontInvalidate_100) { MockHttpCache cache; @@ -5333,7 +5394,7 @@ TEST(HttpCache, CachedRedirect) { MockHttpRequest request(kTestTransaction); net::TestCompletionCallback callback; - // write to the cache + // Write to the cache. { scoped_ptr<net::HttpTransaction> trans; int rv = cache.http_cache()->CreateTransaction( @@ -5355,6 +5416,9 @@ TEST(HttpCache, CachedRedirect) { info->headers->EnumerateHeader(NULL, "Location", &location); EXPECT_EQ(location, "http://www.bar.com/"); + // Mark the transaction as completed so it is cached. + trans->DoneReading(); + // Destroy transaction when going out of scope. We have not actually // read the response body -- want to test that it is still getting cached. } @@ -5362,7 +5426,12 @@ TEST(HttpCache, CachedRedirect) { EXPECT_EQ(0, cache.disk_cache()->open_count()); EXPECT_EQ(1, cache.disk_cache()->create_count()); - // read from the cache + // Active entries in the cache are not retired synchronously. Make + // sure the next run hits the MockHttpCache and open_count is + // correct. + base::MessageLoop::current()->RunUntilIdle(); + + // Read from the cache. { scoped_ptr<net::HttpTransaction> trans; int rv = cache.http_cache()->CreateTransaction( @@ -5384,6 +5453,9 @@ TEST(HttpCache, CachedRedirect) { info->headers->EnumerateHeader(NULL, "Location", &location); EXPECT_EQ(location, "http://www.bar.com/"); + // Mark the transaction as completed so it is cached. + trans->DoneReading(); + // Destroy transaction when going out of scope. We have not actually // read the response body -- want to test that it is still getting cached. } @@ -5392,6 +5464,61 @@ TEST(HttpCache, CachedRedirect) { EXPECT_EQ(1, cache.disk_cache()->create_count()); } +// Verify that no-cache resources are stored in cache, but are not fetched from +// cache during normal loads. +TEST(HttpCache, CacheControlNoCacheNormalLoad) { + MockHttpCache cache; + + ScopedMockTransaction transaction(kSimpleGET_Transaction); + transaction.response_headers = "cache-control: no-cache\n"; + + // Initial load. + RunTransactionTest(cache.http_cache(), transaction); + + EXPECT_EQ(1, cache.network_layer()->transaction_count()); + EXPECT_EQ(0, cache.disk_cache()->open_count()); + EXPECT_EQ(1, cache.disk_cache()->create_count()); + + // Try loading again; it should result in a network fetch. + RunTransactionTest(cache.http_cache(), transaction); + + EXPECT_EQ(2, cache.network_layer()->transaction_count()); + EXPECT_EQ(1, cache.disk_cache()->open_count()); + EXPECT_EQ(1, cache.disk_cache()->create_count()); + + disk_cache::Entry* entry; + EXPECT_TRUE(cache.OpenBackendEntry(transaction.url, &entry)); + entry->Close(); +} + +// Verify that no-cache resources are stored in cache and fetched from cache +// when the LOAD_PREFERRING_CACHE flag is set. +TEST(HttpCache, CacheControlNoCacheHistoryLoad) { + MockHttpCache cache; + + ScopedMockTransaction transaction(kSimpleGET_Transaction); + transaction.response_headers = "cache-control: no-cache\n"; + + // Initial load. + RunTransactionTest(cache.http_cache(), transaction); + + EXPECT_EQ(1, cache.network_layer()->transaction_count()); + EXPECT_EQ(0, cache.disk_cache()->open_count()); + EXPECT_EQ(1, cache.disk_cache()->create_count()); + + // Try loading again with LOAD_PREFERRING_CACHE. + transaction.load_flags = net::LOAD_PREFERRING_CACHE; + RunTransactionTest(cache.http_cache(), transaction); + + EXPECT_EQ(1, cache.network_layer()->transaction_count()); + EXPECT_EQ(1, cache.disk_cache()->open_count()); + EXPECT_EQ(1, cache.disk_cache()->create_count()); + + disk_cache::Entry* entry; + EXPECT_TRUE(cache.OpenBackendEntry(transaction.url, &entry)); + entry->Close(); +} + TEST(HttpCache, CacheControlNoStore) { MockHttpCache cache; @@ -5783,7 +5910,40 @@ TEST(HttpCache, FilterCompletion) { EXPECT_EQ(1, cache.disk_cache()->create_count()); } -// Tests that we stop cachining when told. +// Tests that we don't mark entries as truncated and release the cache +// entry when DoneReading() is called before any Read() calls, such as +// for a redirect. +TEST(HttpCache, DoneReading) { + MockHttpCache cache; + net::TestCompletionCallback callback; + + ScopedMockTransaction transaction(kSimpleGET_Transaction); + transaction.data = ""; + + scoped_ptr<net::HttpTransaction> trans; + int rv = cache.http_cache()->CreateTransaction( + net::DEFAULT_PRIORITY, &trans, NULL); + EXPECT_EQ(net::OK, rv); + + MockHttpRequest request(transaction); + rv = trans->Start(&request, callback.callback(), net::BoundNetLog()); + EXPECT_EQ(net::OK, callback.GetResult(rv)); + + trans->DoneReading(); + // Leave the transaction around. + + // Make sure that the ActiveEntry is gone. + base::MessageLoop::current()->RunUntilIdle(); + + // Read from the cache. This should not deadlock. + RunTransactionTest(cache.http_cache(), transaction); + + EXPECT_EQ(1, cache.network_layer()->transaction_count()); + EXPECT_EQ(1, cache.disk_cache()->open_count()); + EXPECT_EQ(1, cache.disk_cache()->create_count()); +} + +// Tests that we stop caching when told. TEST(HttpCache, StopCachingDeletesEntry) { MockHttpCache cache; net::TestCompletionCallback callback; @@ -5800,7 +5960,7 @@ TEST(HttpCache, StopCachingDeletesEntry) { scoped_refptr<net::IOBuffer> buf(new net::IOBuffer(256)); rv = trans->Read(buf.get(), 10, callback.callback()); - EXPECT_EQ(callback.GetResult(rv), 10); + EXPECT_EQ(10, callback.GetResult(rv)); trans->StopCaching(); @@ -5808,8 +5968,88 @@ TEST(HttpCache, StopCachingDeletesEntry) { rv = trans->Read(buf.get(), 256, callback.callback()); EXPECT_GT(callback.GetResult(rv), 0); rv = trans->Read(buf.get(), 256, callback.callback()); - EXPECT_EQ(callback.GetResult(rv), 0); + EXPECT_EQ(0, callback.GetResult(rv)); + } + + // Make sure that the ActiveEntry is gone. + base::MessageLoop::current()->RunUntilIdle(); + + // Verify that the entry is gone. + RunTransactionTest(cache.http_cache(), kSimpleGET_Transaction); + + EXPECT_EQ(2, cache.network_layer()->transaction_count()); + EXPECT_EQ(0, cache.disk_cache()->open_count()); + EXPECT_EQ(2, cache.disk_cache()->create_count()); +} + +// Tests that we stop caching when told, even if DoneReading is called +// after StopCaching. +TEST(HttpCache, StopCachingThenDoneReadingDeletesEntry) { + MockHttpCache cache; + net::TestCompletionCallback callback; + MockHttpRequest request(kSimpleGET_Transaction); + + { + scoped_ptr<net::HttpTransaction> trans; + int rv = cache.http_cache()->CreateTransaction( + net::DEFAULT_PRIORITY, &trans, NULL); + EXPECT_EQ(net::OK, rv); + + rv = trans->Start(&request, callback.callback(), net::BoundNetLog()); + EXPECT_EQ(net::OK, callback.GetResult(rv)); + + scoped_refptr<net::IOBuffer> buf(new net::IOBuffer(256)); + rv = trans->Read(buf.get(), 10, callback.callback()); + EXPECT_EQ(10, callback.GetResult(rv)); + + trans->StopCaching(); + + // We should be able to keep reading. + rv = trans->Read(buf.get(), 256, callback.callback()); + EXPECT_GT(callback.GetResult(rv), 0); + rv = trans->Read(buf.get(), 256, callback.callback()); + EXPECT_EQ(0, callback.GetResult(rv)); + + // We should be able to call DoneReading. + trans->DoneReading(); + } + + // Make sure that the ActiveEntry is gone. + base::MessageLoop::current()->RunUntilIdle(); + + // Verify that the entry is gone. + RunTransactionTest(cache.http_cache(), kSimpleGET_Transaction); + + EXPECT_EQ(2, cache.network_layer()->transaction_count()); + EXPECT_EQ(0, cache.disk_cache()->open_count()); + EXPECT_EQ(2, cache.disk_cache()->create_count()); +} + +// Tests that we stop caching when told, when using auth. +TEST(HttpCache, StopCachingWithAuthDeletesEntry) { + MockHttpCache cache; + net::TestCompletionCallback callback; + MockTransaction mock_transaction(kSimpleGET_Transaction); + mock_transaction.status = "HTTP/1.1 401 Unauthorized"; + AddMockTransaction(&mock_transaction); + MockHttpRequest request(mock_transaction); + + { + scoped_ptr<net::HttpTransaction> trans; + int rv = cache.http_cache()->CreateTransaction( + net::DEFAULT_PRIORITY, &trans, NULL); + EXPECT_EQ(net::OK, rv); + + rv = trans->Start(&request, callback.callback(), net::BoundNetLog()); + EXPECT_EQ(net::OK, callback.GetResult(rv)); + + trans->StopCaching(); + + scoped_refptr<net::IOBuffer> buf(new net::IOBuffer(256)); + rv = trans->Read(buf.get(), 10, callback.callback()); + EXPECT_EQ(callback.GetResult(rv), 10); } + RemoveMockTransaction(&mock_transaction); // Make sure that the ActiveEntry is gone. base::MessageLoop::current()->RunUntilIdle(); diff --git a/chromium/net/http/http_network_layer_unittest.cc b/chromium/net/http/http_network_layer_unittest.cc index c939d844ba2..6d2ea358071 100644 --- a/chromium/net/http/http_network_layer_unittest.cc +++ b/chromium/net/http/http_network_layer_unittest.cc @@ -48,6 +48,169 @@ class HttpNetworkLayerTest : public PlatformTest { factory_.reset(new HttpNetworkLayer(network_session_.get())); } + void ExecuteRequestExpectingContentAndHeader(const std::string& content, + const std::string& header, + const std::string& value) { + TestCompletionCallback callback; + + HttpRequestInfo request_info; + request_info.url = GURL("http://www.google.com/"); + request_info.method = "GET"; + request_info.load_flags = LOAD_NORMAL; + + scoped_ptr<HttpTransaction> trans; + int rv = factory_->CreateTransaction(DEFAULT_PRIORITY, &trans, NULL); + EXPECT_EQ(OK, rv); + + rv = trans->Start(&request_info, callback.callback(), BoundNetLog()); + if (rv == ERR_IO_PENDING) + rv = callback.WaitForResult(); + ASSERT_EQ(OK, rv); + + std::string contents; + rv = ReadTransaction(trans.get(), &contents); + EXPECT_EQ(OK, rv); + EXPECT_EQ(content, contents); + + if (!header.empty()) { + // We also have a server header here that isn't set by the proxy. + EXPECT_TRUE(trans->GetResponseInfo()->headers->HasHeaderValue( + header, value)); + } + } + + // Check that |proxy_count| proxies are in the retry list. + // These will be, in order, "bad:8080" and "alsobad:8080". + void TestBadProxies(unsigned int proxy_count) { + const ProxyRetryInfoMap& retry_info = proxy_service_->proxy_retry_info(); + ASSERT_EQ(proxy_count, retry_info.size()); + ASSERT_TRUE(retry_info.find("bad:8080") != retry_info.end()); + if (proxy_count > 1) + ASSERT_TRUE(retry_info.find("alsobad:8080") != retry_info.end()); + } + + // Simulates a request through a proxy which returns a bypass, which is then + // retried through a second proxy that doesn't bypass. + // Checks that the expected requests were issued, the expected content was + // recieved, and the first proxy ("bad:8080") was marked as bad. + void TestProxyFallback() { + MockRead data_reads[] = { + MockRead("HTTP/1.1 200 OK\r\n" + "Connection: proxy-bypass\r\n\r\n"), + MockRead("Bypass message"), + MockRead(SYNCHRONOUS, OK), + }; + MockWrite data_writes[] = { + MockWrite("GET http://www.google.com/ HTTP/1.1\r\n" + "Host: www.google.com\r\n" + "Proxy-Connection: keep-alive\r\n\r\n"), + }; + + StaticSocketDataProvider data1(data_reads, arraysize(data_reads), + data_writes, arraysize(data_writes)); + mock_socket_factory_.AddSocketDataProvider(&data1); + + // Second data provider returns the expected content. + MockRead data_reads2[] = { + MockRead("HTTP/1.0 200 OK\r\n" + "Server: not-proxy\r\n\r\n"), + MockRead("content"), + MockRead(SYNCHRONOUS, OK), + }; + MockWrite data_writes2[] = { + MockWrite("GET http://www.google.com/ HTTP/1.1\r\n" + "Host: www.google.com\r\n" + "Proxy-Connection: keep-alive\r\n\r\n"), + }; + StaticSocketDataProvider data2(data_reads2, arraysize(data_reads2), + data_writes2, arraysize(data_writes2)); + mock_socket_factory_.AddSocketDataProvider(&data2); + + // Expect that we get "content" and not "Bypass message", and that there's + // a "not-proxy" "Server:" header in the final response. + ExecuteRequestExpectingContentAndHeader("content", "server", "not-proxy"); + + // We should also observe the bad proxy in the retry list. + TestBadProxies(1u); + } + + // Simulates a request through a proxy which returns a bypass, which is then + // retried through a direct connection to the origin site. + // Checks that the expected requests were issued, the expected content was + // received, and the proxy ("bad:8080") was marked as bad. + void TestProxyFallbackToDirect() { + MockRead data_reads[] = { + MockRead("HTTP/1.1 200 OK\r\n" + "Connection: proxy-bypass\r\n\r\n"), + MockRead("Bypass message"), + MockRead(SYNCHRONOUS, OK), + }; + MockWrite data_writes[] = { + MockWrite("GET http://www.google.com/ HTTP/1.1\r\n" + "Host: www.google.com\r\n" + "Proxy-Connection: keep-alive\r\n\r\n"), + }; + StaticSocketDataProvider data1(data_reads, arraysize(data_reads), + data_writes, arraysize(data_writes)); + mock_socket_factory_.AddSocketDataProvider(&data1); + + // Second data provider returns the expected content. + MockRead data_reads2[] = { + MockRead("HTTP/1.0 200 OK\r\n" + "Server: not-proxy\r\n\r\n"), + MockRead("content"), + MockRead(SYNCHRONOUS, OK), + }; + MockWrite data_writes2[] = { + MockWrite("GET / HTTP/1.1\r\n" + "Host: www.google.com\r\n" + "Connection: keep-alive\r\n\r\n"), + }; + StaticSocketDataProvider data2(data_reads2, arraysize(data_reads2), + data_writes2, arraysize(data_writes2)); + mock_socket_factory_.AddSocketDataProvider(&data2); + + // Expect that we get "content" and not "Bypass message", and that there's + // a "not-proxy" "Server:" header in the final response. + ExecuteRequestExpectingContentAndHeader("content", "server", "not-proxy"); + + // We should also observe the bad proxy in the retry list. + TestBadProxies(1u); + } + + // Simulates a request through a proxy which returns a bypass, under a + // configuration where there is no valid bypass. |proxy_count| proxies + // are expected to be configured. + // Checks that the expected requests were issued, the bypass message was the + // final received content, and all proxies were marked as bad. + void TestProxyFallbackFail(unsigned int proxy_count) { + MockRead data_reads[] = { + MockRead("HTTP/1.1 200 OK\r\n" + "Connection: proxy-bypass\r\n\r\n"), + MockRead("Bypass message"), + MockRead(SYNCHRONOUS, OK), + }; + MockWrite data_writes[] = { + MockWrite("GET http://www.google.com/ HTTP/1.1\r\n" + "Host: www.google.com\r\n" + "Proxy-Connection: keep-alive\r\n\r\n"), + }; + StaticSocketDataProvider data1(data_reads, arraysize(data_reads), + data_writes, arraysize(data_writes)); + StaticSocketDataProvider data2(data_reads, arraysize(data_reads), + data_writes, arraysize(data_writes)); + + mock_socket_factory_.AddSocketDataProvider(&data1); + if (proxy_count > 1) + mock_socket_factory_.AddSocketDataProvider(&data2); + + // Expect that we get "Bypass message", and not "content".. + ExecuteRequestExpectingContentAndHeader("Bypass message", "", ""); + + // We should also observe the bad proxy or proxies in the retry list. + TestBadProxies(proxy_count); + } + MockClientSocketFactory mock_socket_factory_; MockHostResolver host_resolver_; scoped_ptr<CertVerifier> cert_verifier_; @@ -125,76 +288,60 @@ TEST_F(HttpNetworkLayerTest, GET) { EXPECT_EQ("hello world", contents); } -TEST_F(HttpNetworkLayerTest, ServerFallback) { - // Verify that a Connection: Proxy-Bypass header induces proxy fallback to - // a second proxy, if configured. - - // To configure this test, we need to wire up a custom proxy service to use - // a pair of proxies. We'll induce fallback via the first and return - // the expected data via the second. +// Proxy bypass tests. These tests run through various server-induced +// proxy-bypass scenarios using both PAC file and fixed proxy params. +// The test scenarios are: +// - bypass with two proxies configured and the first but not the second +// is bypassed. +// - bypass with one proxy configured and an explicit fallback to direct +// connections +// - bypass with two proxies configured and both are bypassed +// - bypass with one proxy configured which is bypassed with no defined +// fallback + +TEST_F(HttpNetworkLayerTest, ServerTwoProxyBypassPac) { ConfigureTestDependencies(ProxyService::CreateFixedFromPacResult( "PROXY bad:8080; PROXY good:8080")); + TestProxyFallback(); +} - MockRead data_reads[] = { - MockRead("HTTP/1.1 200 OK\r\n" - "Connection: proxy-bypass\r\n\r\n"), - MockRead("Bypass message"), - MockRead(SYNCHRONOUS, OK), - }; - MockWrite data_writes[] = { - MockWrite("GET http://www.google.com/ HTTP/1.1\r\n" - "Host: www.google.com\r\n" - "Proxy-Connection: keep-alive\r\n\r\n"), - }; - StaticSocketDataProvider data1(data_reads, arraysize(data_reads), - data_writes, arraysize(data_writes)); - mock_socket_factory_.AddSocketDataProvider(&data1); - - // Second data provider returns the expected content. - MockRead data_reads2[] = { - MockRead("HTTP/1.0 200 OK\r\n" - "Server: not-proxy\r\n\r\n"), - MockRead("content"), - MockRead(SYNCHRONOUS, OK), - }; - MockWrite data_writes2[] = { - MockWrite("GET http://www.google.com/ HTTP/1.1\r\n" - "Host: www.google.com\r\n" - "Proxy-Connection: keep-alive\r\n\r\n"), - }; - StaticSocketDataProvider data2(data_reads2, arraysize(data_reads2), - data_writes2, arraysize(data_writes2)); - mock_socket_factory_.AddSocketDataProvider(&data2); +TEST_F(HttpNetworkLayerTest, ServerTwoProxyBypassFixed) { + ConfigureTestDependencies(ProxyService::CreateFixed("bad:8080, good:8080")); + TestProxyFallback(); +} - TestCompletionCallback callback; +TEST_F(HttpNetworkLayerTest, ServerOneProxyWithDirectBypassPac) { + ConfigureTestDependencies(ProxyService::CreateFixedFromPacResult( + "PROXY bad:8080; DIRECT")); + TestProxyFallbackToDirect(); +} - HttpRequestInfo request_info; - request_info.url = GURL("http://www.google.com/"); - request_info.method = "GET"; - request_info.load_flags = LOAD_NORMAL; +TEST_F(HttpNetworkLayerTest, ServerOneProxyWithDirectBypassFixed) { + ConfigureTestDependencies(ProxyService::CreateFixed( "bad:8080, direct://")); + TestProxyFallbackToDirect(); +} - scoped_ptr<HttpTransaction> trans; - int rv = factory_->CreateTransaction(DEFAULT_PRIORITY, &trans, NULL); - EXPECT_EQ(OK, rv); +TEST_F(HttpNetworkLayerTest, ServerTwoProxyDoubleBypassPac) { + ConfigureTestDependencies(ProxyService::CreateFixedFromPacResult( + "PROXY bad:8080; PROXY alsobad:8080")); + TestProxyFallbackFail(2u); +} - rv = trans->Start(&request_info, callback.callback(), BoundNetLog()); - if (rv == ERR_IO_PENDING) - rv = callback.WaitForResult(); - ASSERT_EQ(OK, rv); +TEST_F(HttpNetworkLayerTest, ServerTwoProxyDoubleBypassFixed) { + ConfigureTestDependencies(ProxyService::CreateFixed( + "bad:8080, alsobad:8080")); + TestProxyFallbackFail(2u); +} - std::string contents; - rv = ReadTransaction(trans.get(), &contents); - EXPECT_EQ(OK, rv); +TEST_F(HttpNetworkLayerTest, ServerOneProxyNoDirectBypassPac) { + ConfigureTestDependencies(ProxyService::CreateFixedFromPacResult( + "PROXY bad:8080")); + TestProxyFallbackFail(1u); +} - // We should obtain content from the second socket provider write - // corresponding to the fallback proxy. - EXPECT_EQ("content", contents); - // We also have a server header here that isn't set by the proxy. - EXPECT_TRUE(trans->GetResponseInfo()->headers->HasHeaderValue( - "server", "not-proxy")); - // We should also observe the bad proxy in the retry list. - ASSERT_TRUE(1u == proxy_service_->proxy_retry_info().size()); - EXPECT_EQ("bad:8080", (*proxy_service_->proxy_retry_info().begin()).first); +TEST_F(HttpNetworkLayerTest, ServerOneProxyNoDirectBypassFixed) { + ConfigureTestDependencies(ProxyService::CreateFixed("bad:8080")); + TestProxyFallbackFail(1u); } #if defined(SPDY_PROXY_AUTH_ORIGIN) @@ -274,60 +421,7 @@ TEST_F(HttpNetworkLayerTest, ServerFallbackOnInternalServerError) { } #endif // defined(SPDY_PROXY_AUTH_ORIGIN) -TEST_F(HttpNetworkLayerTest, ServerFallbackDoesntLoop) { - // Verify that a Connection: Proxy-Bypass header will display the original - // proxy's error page content if a fallback option is not configured. - ConfigureTestDependencies(ProxyService::CreateFixedFromPacResult( - "PROXY bad:8080; PROXY alsobad:8080")); - - MockRead data_reads[] = { - MockRead("HTTP/1.1 200 OK\r\n" - "Connection: proxy-bypass\r\n\r\n"), - MockRead("Bypass message"), - MockRead(SYNCHRONOUS, OK), - }; - MockWrite data_writes[] = { - MockWrite("GET http://www.google.com/ HTTP/1.1\r\n" - "Host: www.google.com\r\n" - "Proxy-Connection: keep-alive\r\n\r\n"), - }; - StaticSocketDataProvider data1(data_reads, arraysize(data_reads), - data_writes, arraysize(data_writes)); - StaticSocketDataProvider data2(data_reads, arraysize(data_reads), - data_writes, arraysize(data_writes)); - mock_socket_factory_.AddSocketDataProvider(&data1); - mock_socket_factory_.AddSocketDataProvider(&data2); - - TestCompletionCallback callback; - - HttpRequestInfo request_info; - request_info.url = GURL("http://www.google.com/"); - request_info.method = "GET"; - request_info.load_flags = LOAD_NORMAL; - - scoped_ptr<HttpTransaction> trans; - int rv = factory_->CreateTransaction(DEFAULT_PRIORITY, &trans, NULL); - EXPECT_EQ(OK, rv); - - rv = trans->Start(&request_info, callback.callback(), BoundNetLog()); - if (rv == ERR_IO_PENDING) - rv = callback.WaitForResult(); - ASSERT_EQ(OK, rv); - - std::string contents; - rv = ReadTransaction(trans.get(), &contents); - EXPECT_EQ(OK, rv); - EXPECT_EQ("Bypass message", contents); - - // Despite not falling back to anything, we should still observe the proxies - // in the bad proxies list. - const ProxyRetryInfoMap& retry_info = proxy_service_->proxy_retry_info(); - ASSERT_EQ(2u, retry_info.size()); - ASSERT_TRUE(retry_info.find("bad:8080") != retry_info.end()); - ASSERT_TRUE(retry_info.find("alsobad:8080") != retry_info.end()); -} - -TEST_F(HttpNetworkLayerTest, ProxyBypassIgnoredOnDirectConnection) { +TEST_F(HttpNetworkLayerTest, ProxyBypassIgnoredOnDirectConnectionPac) { // Verify that a Connection: proxy-bypass header is ignored when returned // from a directly connected origin server. ConfigureTestDependencies(ProxyService::CreateFixedFromPacResult("DIRECT")); diff --git a/chromium/net/http/http_network_session.cc b/chromium/net/http/http_network_session.cc index 32659629e90..346cbbc8941 100644 --- a/chromium/net/http/http_network_session.cc +++ b/chromium/net/http/http_network_session.cc @@ -107,6 +107,7 @@ HttpNetworkSession::HttpNetworkSession(const Params& params) params.client_socket_factory ? params.client_socket_factory : net::ClientSocketFactory::GetDefaultFactory(), + params.http_server_properties, params.quic_crypto_client_stream_factory, params.quic_random ? params.quic_random : QuicRandom::GetInstance(), diff --git a/chromium/net/http/http_network_transaction.cc b/chromium/net/http/http_network_transaction.cc index 70292be257f..ff158a84e77 100644 --- a/chromium/net/http/http_network_transaction.cc +++ b/chromium/net/http/http_network_transaction.cc @@ -419,8 +419,10 @@ bool HttpNetworkTransaction::GetLoadTimingInfo( void HttpNetworkTransaction::SetPriority(RequestPriority priority) { priority_ = priority; - // TODO(akalin): Plumb this through to |stream_request_| and - // |stream_|. + if (stream_request_) + stream_request_->SetPriority(priority); + if (stream_) + stream_->SetPriority(priority); } void HttpNetworkTransaction::OnStreamReady(const SSLConfig& used_ssl_config, @@ -730,12 +732,15 @@ int HttpNetworkTransaction::DoGenerateProxyAuthTokenComplete(int rv) { int HttpNetworkTransaction::DoGenerateServerAuthToken() { next_state_ = STATE_GENERATE_SERVER_AUTH_TOKEN_COMPLETE; HttpAuth::Target target = HttpAuth::AUTH_SERVER; - if (!auth_controllers_[target].get()) + if (!auth_controllers_[target].get()) { auth_controllers_[target] = new HttpAuthController(target, AuthURL(target), session_->http_auth_cache(), session_->http_auth_handler_factory()); + if (request_->load_flags & LOAD_DO_NOT_USE_EMBEDDED_IDENTITY) + auth_controllers_[target]->DisableEmbeddedIdentity(); + } if (!ShouldApplyServerAuth()) return OK; return auth_controllers_[target]->MaybeGenerateAuthToken(request_, @@ -912,10 +917,8 @@ int HttpNetworkTransaction::DoReadHeadersComplete(int result) { } DCHECK(response_.headers.get()); - // Server-induced fallback is supported only if this is a PAC configured - // proxy. See: http://crbug.com/143712 - if (response_.was_fetched_via_proxy && proxy_info_.did_use_pac_script() && - response_.headers.get() != NULL) { + // Server-induced fallback; see: http://crbug.com/143712 + if (response_.was_fetched_via_proxy && response_.headers.get() != NULL) { bool should_fallback = response_.headers->HasHeaderValue("connection", "proxy-bypass"); // Additionally, fallback if a 500 is returned via the data reduction proxy. diff --git a/chromium/net/http/http_network_transaction_unittest.cc b/chromium/net/http/http_network_transaction_unittest.cc index 5f8dac2f87d..a2165776c08 100644 --- a/chromium/net/http/http_network_transaction_unittest.cc +++ b/chromium/net/http/http_network_transaction_unittest.cc @@ -15,6 +15,7 @@ #include "base/files/file_path.h" #include "base/json/json_writer.h" #include "base/memory/scoped_ptr.h" +#include "base/memory/weak_ptr.h" #include "base/strings/string_util.h" #include "base/strings/utf_string_conversions.h" #include "base/test/test_file_util.h" @@ -45,6 +46,7 @@ #include "net/http/http_stream_factory.h" #include "net/http/http_transaction_unittest.h" #include "net/proxy/proxy_config_service_fixed.h" +#include "net/proxy/proxy_info.h" #include "net/proxy/proxy_resolver.h" #include "net/proxy/proxy_service.h" #include "net/socket/client_socket_factory.h" @@ -58,6 +60,7 @@ #include "net/spdy/spdy_session_pool.h" #include "net/spdy/spdy_test_util_common.h" #include "net/ssl/ssl_cert_request_info.h" +#include "net/ssl/ssl_config_service.h" #include "net/ssl/ssl_config_service_defaults.h" #include "net/ssl/ssl_info.h" #include "net/test/cert_test_util.h" @@ -92,6 +95,11 @@ int GetIdleSocketCountInSSLSocketPool(net::HttpNetworkSession* session) { net::HttpNetworkSession::NORMAL_SOCKET_POOL)->IdleSocketCount(); } +bool IsTransportSocketPoolStalled(net::HttpNetworkSession* session) { + return session->GetTransportSocketPool( + net::HttpNetworkSession::NORMAL_SOCKET_POOL)->IsStalled(); +} + // Takes in a Value created from a NetLogHttpResponseParameter, and returns // a JSONified list of headers as a single string. Uses single quotes instead // of double quotes for easier comparison. Returns false on failure. @@ -449,7 +457,7 @@ class CaptureGroupNameSocketPool : public ParentPool { virtual void CancelRequest(const std::string& group_name, ClientSocketHandle* handle) {} virtual void ReleaseSocket(const std::string& group_name, - StreamSocket* socket, + scoped_ptr<StreamSocket> socket, int id) {} virtual void CloseIdleSockets() {} virtual int IdleSocketCount() const { @@ -2792,7 +2800,8 @@ TEST_P(HttpNetworkTransactionTest, HttpsProxySpdyConnectHttps) { new HttpNetworkTransaction(DEFAULT_PRIORITY, session.get())); // CONNECT to www.google.com:443 via SPDY - scoped_ptr<SpdyFrame> connect(spdy_util_.ConstructSpdyConnect(NULL, 0, 1)); + scoped_ptr<SpdyFrame> connect(spdy_util_.ConstructSpdyConnect(NULL, 0, 1, + LOWEST)); // fetch https://www.google.com/ via HTTP const char get[] = "GET / HTTP/1.1\r\n" @@ -2878,7 +2887,8 @@ TEST_P(HttpNetworkTransactionTest, HttpsProxySpdyConnectSpdy) { new HttpNetworkTransaction(DEFAULT_PRIORITY, session.get())); // CONNECT to www.google.com:443 via SPDY - scoped_ptr<SpdyFrame> connect(spdy_util_.ConstructSpdyConnect(NULL, 0, 1)); + scoped_ptr<SpdyFrame> connect(spdy_util_.ConstructSpdyConnect(NULL, 0, 1, + LOWEST)); // fetch https://www.google.com/ via SPDY const char* const kMyUrl = "https://www.google.com/"; scoped_ptr<SpdyFrame> get( @@ -2966,7 +2976,8 @@ TEST_P(HttpNetworkTransactionTest, HttpsProxySpdyConnectFailure) { new HttpNetworkTransaction(DEFAULT_PRIORITY, session.get())); // CONNECT to www.google.com:443 via SPDY - scoped_ptr<SpdyFrame> connect(spdy_util_.ConstructSpdyConnect(NULL, 0, 1)); + scoped_ptr<SpdyFrame> connect(spdy_util_.ConstructSpdyConnect(NULL, 0, 1, + LOWEST)); scoped_ptr<SpdyFrame> get( spdy_util_.ConstructSpdyRstStream(1, RST_STREAM_CANCEL)); @@ -3028,7 +3039,8 @@ TEST_P(HttpNetworkTransactionTest, request2.load_flags = 0; // CONNECT to www.google.com:443 via SPDY. - scoped_ptr<SpdyFrame> connect1(spdy_util_.ConstructSpdyConnect(NULL, 0, 1)); + scoped_ptr<SpdyFrame> connect1(spdy_util_.ConstructSpdyConnect(NULL, 0, 1, + LOWEST)); scoped_ptr<SpdyFrame> conn_resp1( spdy_util_.ConstructSpdyGetSynReply(NULL, 0, 1)); @@ -3186,7 +3198,8 @@ TEST_P(HttpNetworkTransactionTest, request2.load_flags = 0; // CONNECT to www.google.com:443 via SPDY. - scoped_ptr<SpdyFrame> connect1(spdy_util_.ConstructSpdyConnect(NULL, 0, 1)); + scoped_ptr<SpdyFrame> connect1(spdy_util_.ConstructSpdyConnect(NULL, 0, 1, + LOWEST)); scoped_ptr<SpdyFrame> conn_resp1( spdy_util_.ConstructSpdyGetSynReply(NULL, 0, 1)); @@ -3854,7 +3867,10 @@ TEST_P(HttpNetworkTransactionTest, NTLMAuth1) { HttpRequestInfo request; request.method = "GET"; request.url = GURL("http://172.22.68.17/kids/login.aspx"); - request.load_flags = 0; + + // Ensure load is not disrupted by flags which suppress behaviour specific + // to other auth schemes. + request.load_flags = LOAD_DO_NOT_USE_EMBEDDED_IDENTITY; HttpAuthHandlerNTLM::ScopedProcSetter proc_setter(MockGenerateRandom1, MockGetHostName); @@ -4821,6 +4837,86 @@ TEST_P(HttpNetworkTransactionTest, WrongAuthIdentityInURL) { base::MessageLoop::current()->RunUntilIdle(); } + +// Test the request-challenge-retry sequence for basic auth when there is a +// correct identity in the URL, but its use is being suppressed. The identity +// from the URL should never be used. +TEST_P(HttpNetworkTransactionTest, AuthIdentityInURLSuppressed) { + HttpRequestInfo request; + request.method = "GET"; + request.url = GURL("http://foo:bar@www.google.com/"); + request.load_flags = LOAD_DO_NOT_USE_EMBEDDED_IDENTITY; + + scoped_ptr<HttpTransaction> trans( + new HttpNetworkTransaction(DEFAULT_PRIORITY, + CreateSession(&session_deps_))); + + MockWrite data_writes1[] = { + MockWrite("GET / HTTP/1.1\r\n" + "Host: www.google.com\r\n" + "Connection: keep-alive\r\n\r\n"), + }; + + MockRead data_reads1[] = { + MockRead("HTTP/1.0 401 Unauthorized\r\n"), + MockRead("WWW-Authenticate: Basic realm=\"MyRealm1\"\r\n"), + MockRead("Content-Length: 10\r\n\r\n"), + MockRead(SYNCHRONOUS, ERR_FAILED), + }; + + // After the challenge above, the transaction will be restarted using the + // identity supplied by the user, not the one in the URL, to answer the + // challenge. + MockWrite data_writes3[] = { + MockWrite("GET / HTTP/1.1\r\n" + "Host: www.google.com\r\n" + "Connection: keep-alive\r\n" + "Authorization: Basic Zm9vOmJhcg==\r\n\r\n"), + }; + + MockRead data_reads3[] = { + MockRead("HTTP/1.0 200 OK\r\n"), + MockRead("Content-Length: 100\r\n\r\n"), + MockRead(SYNCHRONOUS, OK), + }; + + StaticSocketDataProvider data1(data_reads1, arraysize(data_reads1), + data_writes1, arraysize(data_writes1)); + StaticSocketDataProvider data3(data_reads3, arraysize(data_reads3), + data_writes3, arraysize(data_writes3)); + session_deps_.socket_factory->AddSocketDataProvider(&data1); + session_deps_.socket_factory->AddSocketDataProvider(&data3); + + TestCompletionCallback callback1; + int rv = trans->Start(&request, callback1.callback(), BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + rv = callback1.WaitForResult(); + EXPECT_EQ(OK, rv); + EXPECT_FALSE(trans->IsReadyToRestartForAuth()); + + const HttpResponseInfo* response = trans->GetResponseInfo(); + ASSERT_TRUE(response != NULL); + EXPECT_TRUE(CheckBasicServerAuth(response->auth_challenge.get())); + + TestCompletionCallback callback3; + rv = trans->RestartWithAuth( + AuthCredentials(kFoo, kBar), callback3.callback()); + EXPECT_EQ(ERR_IO_PENDING, rv); + rv = callback3.WaitForResult(); + EXPECT_EQ(OK, rv); + EXPECT_FALSE(trans->IsReadyToRestartForAuth()); + + response = trans->GetResponseInfo(); + ASSERT_TRUE(response != NULL); + + // There is no challenge info, since the identity worked. + EXPECT_TRUE(response->auth_challenge.get() == NULL); + EXPECT_EQ(100, response->headers->GetContentLength()); + + // Empty the current queue. + base::MessageLoop::current()->RunUntilIdle(); +} + // Test that previously tried username/passwords for a realm get re-used. TEST_P(HttpNetworkTransactionTest, BasicAuthCacheAndPreauth) { scoped_refptr<HttpNetworkSession> session(CreateSession(&session_deps_)); @@ -5667,7 +5763,8 @@ TEST_P(HttpNetworkTransactionTest, RedirectOfHttpsConnectViaSpdyProxy) { request.url = GURL("https://www.google.com/"); request.load_flags = 0; - scoped_ptr<SpdyFrame> conn(spdy_util_.ConstructSpdyConnect(NULL, 0, 1)); + scoped_ptr<SpdyFrame> conn(spdy_util_.ConstructSpdyConnect(NULL, 0, 1, + LOWEST)); scoped_ptr<SpdyFrame> goaway( spdy_util_.ConstructSpdyRstStream(1, RST_STREAM_CANCEL)); MockWrite data_writes[] = { @@ -5775,7 +5872,8 @@ TEST_P(HttpNetworkTransactionTest, request.url = GURL("https://www.google.com/"); request.load_flags = 0; - scoped_ptr<SpdyFrame> conn(spdy_util_.ConstructSpdyConnect(NULL, 0, 1)); + scoped_ptr<SpdyFrame> conn(spdy_util_.ConstructSpdyConnect(NULL, 0, 1, + LOWEST)); scoped_ptr<SpdyFrame> rst( spdy_util_.ConstructSpdyRstStream(1, RST_STREAM_CANCEL)); MockWrite data_writes[] = { @@ -5841,7 +5939,8 @@ TEST_P(HttpNetworkTransactionTest, BasicAuthSpdyProxy) { scoped_refptr<HttpNetworkSession> session(CreateSession(&session_deps_)); // Since we have proxy, should try to establish tunnel. - scoped_ptr<SpdyFrame> req(spdy_util_.ConstructSpdyConnect(NULL, 0, 1)); + scoped_ptr<SpdyFrame> req(spdy_util_.ConstructSpdyConnect(NULL, 0, 1, + LOWEST)); scoped_ptr<SpdyFrame> rst( spdy_util_.ConstructSpdyRstStream(1, RST_STREAM_CANCEL)); @@ -5851,7 +5950,7 @@ TEST_P(HttpNetworkTransactionTest, BasicAuthSpdyProxy) { "proxy-authorization", "Basic Zm9vOmJhcg==", }; scoped_ptr<SpdyFrame> connect2(spdy_util_.ConstructSpdyConnect( - kAuthCredentials, arraysize(kAuthCredentials) / 2, 3)); + kAuthCredentials, arraysize(kAuthCredentials) / 2, 3, LOWEST)); // fetch https://www.google.com/ via HTTP const char get[] = "GET / HTTP/1.1\r\n" "Host: www.google.com\r\n" @@ -7271,8 +7370,12 @@ void HttpNetworkTransactionTest::BypassHostCacheOnRefreshHelper( AddressList addrlist; TestCompletionCallback callback; int rv = session_deps_.host_resolver->Resolve( - HostResolver::RequestInfo(HostPortPair("www.google.com", 80)), &addrlist, - callback.callback(), NULL, BoundNetLog()); + HostResolver::RequestInfo(HostPortPair("www.google.com", 80)), + DEFAULT_PRIORITY, + &addrlist, + callback.callback(), + NULL, + BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); rv = callback.WaitForResult(); EXPECT_EQ(OK, rv); @@ -7280,8 +7383,12 @@ void HttpNetworkTransactionTest::BypassHostCacheOnRefreshHelper( // Verify that it was added to host cache, by doing a subsequent async lookup // and confirming it completes synchronously. rv = session_deps_.host_resolver->Resolve( - HostResolver::RequestInfo(HostPortPair("www.google.com", 80)), &addrlist, - callback.callback(), NULL, BoundNetLog()); + HostResolver::RequestInfo(HostPortPair("www.google.com", 80)), + DEFAULT_PRIORITY, + &addrlist, + callback.callback(), + NULL, + BoundNetLog()); ASSERT_EQ(OK, rv); // Inject a failure the next time that "www.google.com" is resolved. This way @@ -10476,9 +10583,12 @@ WRAPPED_TEST_P(HttpNetworkTransactionTest, MAYBE_UseIPConnectionPooling) { HostPortPair host_port("www.gmail.com", 443); HostResolver::RequestInfo resolve_info(host_port); AddressList ignored; - rv = session_deps_.host_resolver->Resolve(resolve_info, &ignored, - callback.callback(), NULL, - BoundNetLog()); + rv = session_deps_.host_resolver->Resolve(resolve_info, + DEFAULT_PRIORITY, + &ignored, + callback.callback(), + NULL, + BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); rv = callback.WaitForResult(); EXPECT_EQ(OK, rv); @@ -10602,12 +10712,13 @@ class OneTimeCachingHostResolver : public net::HostResolver { // HostResolver methods: virtual int Resolve(const RequestInfo& info, + RequestPriority priority, AddressList* addresses, const CompletionCallback& callback, RequestHandle* out_req, const BoundNetLog& net_log) OVERRIDE { return host_resolver_.Resolve( - info, addresses, callback, out_req, net_log); + info, priority, addresses, callback, out_req, net_log); } virtual int ResolveFromCache(const RequestInfo& info, @@ -10721,8 +10832,12 @@ WRAPPED_TEST_P(HttpNetworkTransactionTest, // Preload cache entries into HostCache. HostResolver::RequestInfo resolve_info(HostPortPair("www.gmail.com", 443)); AddressList ignored; - rv = host_resolver.Resolve(resolve_info, &ignored, callback.callback(), - NULL, BoundNetLog()); + rv = host_resolver.Resolve(resolve_info, + DEFAULT_PRIORITY, + &ignored, + callback.callback(), + NULL, + BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); rv = callback.WaitForResult(); EXPECT_EQ(OK, rv); @@ -10881,7 +10996,8 @@ TEST_P(HttpNetworkTransactionTest, DoNotUseSpdySessionForHttpOverTunnel) { const std::string http_url = "http://www.google.com:443/"; // SPDY GET for HTTPS URL (through CONNECT tunnel) - scoped_ptr<SpdyFrame> connect(spdy_util_.ConstructSpdyConnect(NULL, 0, 1)); + scoped_ptr<SpdyFrame> connect(spdy_util_.ConstructSpdyConnect(NULL, 0, 1, + LOWEST)); scoped_ptr<SpdyFrame> req1( spdy_util_.ConstructSpdyGet(https_url.c_str(), false, 1, LOWEST)); @@ -11680,4 +11796,463 @@ TEST_P(HttpNetworkTransactionTest, GetFullRequestHeadersIncludesExtraHeader) { EXPECT_EQ("bar", foo); } +namespace { + +// Fake HttpStreamBase that simply records calls to SetPriority(). +class FakeStream : public HttpStreamBase, + public base::SupportsWeakPtr<FakeStream> { + public: + explicit FakeStream(RequestPriority priority) : priority_(priority) {} + virtual ~FakeStream() {} + + RequestPriority priority() const { return priority_; } + + virtual int InitializeStream(const HttpRequestInfo* request_info, + RequestPriority priority, + const BoundNetLog& net_log, + const CompletionCallback& callback) OVERRIDE { + return ERR_IO_PENDING; + } + + virtual int SendRequest(const HttpRequestHeaders& request_headers, + HttpResponseInfo* response, + const CompletionCallback& callback) OVERRIDE { + ADD_FAILURE(); + return ERR_UNEXPECTED; + } + + virtual int ReadResponseHeaders(const CompletionCallback& callback) OVERRIDE { + ADD_FAILURE(); + return ERR_UNEXPECTED; + } + + virtual const HttpResponseInfo* GetResponseInfo() const OVERRIDE { + ADD_FAILURE(); + return NULL; + } + + virtual int ReadResponseBody(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) OVERRIDE { + ADD_FAILURE(); + return ERR_UNEXPECTED; + } + + virtual void Close(bool not_reusable) OVERRIDE {} + + virtual bool IsResponseBodyComplete() const OVERRIDE { + ADD_FAILURE(); + return false; + } + + virtual bool CanFindEndOfResponse() const OVERRIDE { + return false; + } + + virtual bool IsConnectionReused() const OVERRIDE { + ADD_FAILURE(); + return false; + } + + virtual void SetConnectionReused() OVERRIDE { + ADD_FAILURE(); + } + + virtual bool IsConnectionReusable() const OVERRIDE { + ADD_FAILURE(); + return false; + } + + virtual bool GetLoadTimingInfo( + LoadTimingInfo* load_timing_info) const OVERRIDE { + ADD_FAILURE(); + return false; + } + + virtual void GetSSLInfo(SSLInfo* ssl_info) OVERRIDE { + ADD_FAILURE(); + } + + virtual void GetSSLCertRequestInfo( + SSLCertRequestInfo* cert_request_info) OVERRIDE { + ADD_FAILURE(); + } + + virtual bool IsSpdyHttpStream() const OVERRIDE { + ADD_FAILURE(); + return false; + } + + virtual void Drain(HttpNetworkSession* session) OVERRIDE { + ADD_FAILURE(); + } + + virtual void SetPriority(RequestPriority priority) OVERRIDE { + priority_ = priority; + } + + private: + RequestPriority priority_; + + DISALLOW_COPY_AND_ASSIGN(FakeStream); +}; + +// Fake HttpStreamRequest that simply records calls to SetPriority() +// and vends FakeStreams with its current priority. +class FakeStreamRequest : public HttpStreamRequest, + public base::SupportsWeakPtr<FakeStreamRequest> { + public: + FakeStreamRequest(RequestPriority priority, + HttpStreamRequest::Delegate* delegate) + : priority_(priority), + delegate_(delegate) {} + + virtual ~FakeStreamRequest() {} + + RequestPriority priority() const { return priority_; } + + // Create a new FakeStream and pass it to the request's + // delegate. Returns a weak pointer to the FakeStream. + base::WeakPtr<FakeStream> FinishStreamRequest() { + FakeStream* fake_stream = new FakeStream(priority_); + // Do this before calling OnStreamReady() as OnStreamReady() may + // immediately delete |fake_stream|. + base::WeakPtr<FakeStream> weak_stream = fake_stream->AsWeakPtr(); + delegate_->OnStreamReady(SSLConfig(), ProxyInfo(), fake_stream); + return weak_stream; + } + + virtual int RestartTunnelWithProxyAuth( + const AuthCredentials& credentials) OVERRIDE { + ADD_FAILURE(); + return ERR_UNEXPECTED; + } + + virtual LoadState GetLoadState() const OVERRIDE { + ADD_FAILURE(); + return LoadState(); + } + + virtual void SetPriority(RequestPriority priority) OVERRIDE { + priority_ = priority; + } + + virtual bool was_npn_negotiated() const OVERRIDE { + return false; + } + + virtual NextProto protocol_negotiated() const OVERRIDE { + return kProtoUnknown; + } + + virtual bool using_spdy() const OVERRIDE { + return false; + } + + private: + RequestPriority priority_; + HttpStreamRequest::Delegate* const delegate_; + + DISALLOW_COPY_AND_ASSIGN(FakeStreamRequest); +}; + +// Fake HttpStreamFactory that vends FakeStreamRequests. +class FakeStreamFactory : public HttpStreamFactory { + public: + FakeStreamFactory() {} + virtual ~FakeStreamFactory() {} + + // Returns a WeakPtr<> to the last HttpStreamRequest returned by + // RequestStream() (which may be NULL if it was destroyed already). + base::WeakPtr<FakeStreamRequest> last_stream_request() { + return last_stream_request_; + } + + virtual HttpStreamRequest* RequestStream( + const HttpRequestInfo& info, + RequestPriority priority, + const SSLConfig& server_ssl_config, + const SSLConfig& proxy_ssl_config, + HttpStreamRequest::Delegate* delegate, + const BoundNetLog& net_log) OVERRIDE { + FakeStreamRequest* fake_request = new FakeStreamRequest(priority, delegate); + last_stream_request_ = fake_request->AsWeakPtr(); + return fake_request; + } + + virtual HttpStreamRequest* RequestWebSocketStream( + const HttpRequestInfo& info, + RequestPriority priority, + const SSLConfig& server_ssl_config, + const SSLConfig& proxy_ssl_config, + HttpStreamRequest::Delegate* delegate, + WebSocketStreamBase::Factory* factory, + const BoundNetLog& net_log) OVERRIDE { + ADD_FAILURE(); + return NULL; + } + + virtual void PreconnectStreams(int num_streams, + const HttpRequestInfo& info, + RequestPriority priority, + const SSLConfig& server_ssl_config, + const SSLConfig& proxy_ssl_config) OVERRIDE { + ADD_FAILURE(); + } + + virtual base::Value* PipelineInfoToValue() const OVERRIDE { + ADD_FAILURE(); + return NULL; + } + + virtual const HostMappingRules* GetHostMappingRules() const OVERRIDE { + ADD_FAILURE(); + return NULL; + } + + private: + base::WeakPtr<FakeStreamRequest> last_stream_request_; + + DISALLOW_COPY_AND_ASSIGN(FakeStreamFactory); +}; + +} // namespace + +// Make sure that HttpNetworkTransaction passes on its priority to its +// stream request on start. +TEST_P(HttpNetworkTransactionTest, SetStreamRequestPriorityOnStart) { + scoped_refptr<HttpNetworkSession> session(CreateSession(&session_deps_)); + HttpNetworkSessionPeer peer(session); + FakeStreamFactory* fake_factory = new FakeStreamFactory(); + peer.SetHttpStreamFactory(fake_factory); + + HttpNetworkTransaction trans(LOW, session); + + ASSERT_TRUE(fake_factory->last_stream_request() == NULL); + + HttpRequestInfo request; + TestCompletionCallback callback; + EXPECT_EQ(ERR_IO_PENDING, + trans.Start(&request, callback.callback(), BoundNetLog())); + + base::WeakPtr<FakeStreamRequest> fake_request = + fake_factory->last_stream_request(); + ASSERT_TRUE(fake_request != NULL); + EXPECT_EQ(LOW, fake_request->priority()); +} + +// Make sure that HttpNetworkTransaction passes on its priority +// updates to its stream request. +TEST_P(HttpNetworkTransactionTest, SetStreamRequestPriority) { + scoped_refptr<HttpNetworkSession> session(CreateSession(&session_deps_)); + HttpNetworkSessionPeer peer(session); + FakeStreamFactory* fake_factory = new FakeStreamFactory(); + peer.SetHttpStreamFactory(fake_factory); + + HttpNetworkTransaction trans(LOW, session); + + HttpRequestInfo request; + TestCompletionCallback callback; + EXPECT_EQ(ERR_IO_PENDING, + trans.Start(&request, callback.callback(), BoundNetLog())); + + base::WeakPtr<FakeStreamRequest> fake_request = + fake_factory->last_stream_request(); + ASSERT_TRUE(fake_request != NULL); + EXPECT_EQ(LOW, fake_request->priority()); + + trans.SetPriority(LOWEST); + ASSERT_TRUE(fake_request != NULL); + EXPECT_EQ(LOWEST, fake_request->priority()); +} + +// Make sure that HttpNetworkTransaction passes on its priority +// updates to its stream. +TEST_P(HttpNetworkTransactionTest, SetStreamPriority) { + scoped_refptr<HttpNetworkSession> session(CreateSession(&session_deps_)); + HttpNetworkSessionPeer peer(session); + FakeStreamFactory* fake_factory = new FakeStreamFactory(); + peer.SetHttpStreamFactory(fake_factory); + + HttpNetworkTransaction trans(LOW, session); + + HttpRequestInfo request; + TestCompletionCallback callback; + EXPECT_EQ(ERR_IO_PENDING, + trans.Start(&request, callback.callback(), BoundNetLog())); + + base::WeakPtr<FakeStreamRequest> fake_request = + fake_factory->last_stream_request(); + ASSERT_TRUE(fake_request != NULL); + base::WeakPtr<FakeStream> fake_stream = fake_request->FinishStreamRequest(); + ASSERT_TRUE(fake_stream != NULL); + EXPECT_EQ(LOW, fake_stream->priority()); + + trans.SetPriority(LOWEST); + EXPECT_EQ(LOWEST, fake_stream->priority()); +} + +// Tests that when a used socket is returned to the SSL socket pool, it's closed +// if the transport socket pool is stalled on the global socket limit. +TEST_P(HttpNetworkTransactionTest, CloseSSLSocketOnIdleForHttpRequest) { + ClientSocketPoolManager::set_max_sockets_per_group( + HttpNetworkSession::NORMAL_SOCKET_POOL, 1); + ClientSocketPoolManager::set_max_sockets_per_pool( + HttpNetworkSession::NORMAL_SOCKET_POOL, 1); + + // Set up SSL request. + + HttpRequestInfo ssl_request; + ssl_request.method = "GET"; + ssl_request.url = GURL("https://www.google.com/"); + + MockWrite ssl_writes[] = { + MockWrite("GET / HTTP/1.1\r\n" + "Host: www.google.com\r\n" + "Connection: keep-alive\r\n\r\n"), + }; + MockRead ssl_reads[] = { + MockRead("HTTP/1.1 200 OK\r\n"), + MockRead("Content-Length: 11\r\n\r\n"), + MockRead("hello world"), + MockRead(SYNCHRONOUS, OK), + }; + StaticSocketDataProvider ssl_data(ssl_reads, arraysize(ssl_reads), + ssl_writes, arraysize(ssl_writes)); + session_deps_.socket_factory->AddSocketDataProvider(&ssl_data); + + SSLSocketDataProvider ssl(ASYNC, OK); + session_deps_.socket_factory->AddSSLSocketDataProvider(&ssl); + + // Set up HTTP request. + + HttpRequestInfo http_request; + http_request.method = "GET"; + http_request.url = GURL("http://www.google.com/"); + + MockWrite http_writes[] = { + MockWrite("GET / HTTP/1.1\r\n" + "Host: www.google.com\r\n" + "Connection: keep-alive\r\n\r\n"), + }; + MockRead http_reads[] = { + MockRead("HTTP/1.1 200 OK\r\n"), + MockRead("Content-Length: 7\r\n\r\n"), + MockRead("falafel"), + MockRead(SYNCHRONOUS, OK), + }; + StaticSocketDataProvider http_data(http_reads, arraysize(http_reads), + http_writes, arraysize(http_writes)); + session_deps_.socket_factory->AddSocketDataProvider(&http_data); + + scoped_refptr<HttpNetworkSession> session(CreateSession(&session_deps_)); + + // Start the SSL request. + TestCompletionCallback ssl_callback; + scoped_ptr<HttpTransaction> ssl_trans( + new HttpNetworkTransaction(DEFAULT_PRIORITY, session.get())); + ASSERT_EQ(ERR_IO_PENDING, + ssl_trans->Start(&ssl_request, ssl_callback.callback(), + BoundNetLog())); + + // Start the HTTP request. Pool should stall. + TestCompletionCallback http_callback; + scoped_ptr<HttpTransaction> http_trans( + new HttpNetworkTransaction(DEFAULT_PRIORITY, session.get())); + ASSERT_EQ(ERR_IO_PENDING, + http_trans->Start(&http_request, http_callback.callback(), + BoundNetLog())); + EXPECT_TRUE(IsTransportSocketPoolStalled(session)); + + // Wait for response from SSL request. + ASSERT_EQ(OK, ssl_callback.WaitForResult()); + std::string response_data; + ASSERT_EQ(OK, ReadTransaction(ssl_trans.get(), &response_data)); + EXPECT_EQ("hello world", response_data); + + // The SSL socket should automatically be closed, so the HTTP request can + // start. + EXPECT_EQ(0, GetIdleSocketCountInSSLSocketPool(session)); + ASSERT_FALSE(IsTransportSocketPoolStalled(session)); + + // The HTTP request can now complete. + ASSERT_EQ(OK, http_callback.WaitForResult()); + ASSERT_EQ(OK, ReadTransaction(http_trans.get(), &response_data)); + EXPECT_EQ("falafel", response_data); + + EXPECT_EQ(1, GetIdleSocketCountInTransportSocketPool(session)); +} + +// Tests that when a SSL connection is established but there's no corresponding +// request that needs it, the new socket is closed if the transport socket pool +// is stalled on the global socket limit. +TEST_P(HttpNetworkTransactionTest, CloseSSLSocketOnIdleForHttpRequest2) { + ClientSocketPoolManager::set_max_sockets_per_group( + HttpNetworkSession::NORMAL_SOCKET_POOL, 1); + ClientSocketPoolManager::set_max_sockets_per_pool( + HttpNetworkSession::NORMAL_SOCKET_POOL, 1); + + // Set up an ssl request. + + HttpRequestInfo ssl_request; + ssl_request.method = "GET"; + ssl_request.url = GURL("https://www.foopy.com/"); + + // No data will be sent on the SSL socket. + StaticSocketDataProvider ssl_data; + session_deps_.socket_factory->AddSocketDataProvider(&ssl_data); + + SSLSocketDataProvider ssl(ASYNC, OK); + session_deps_.socket_factory->AddSSLSocketDataProvider(&ssl); + + // Set up HTTP request. + + HttpRequestInfo http_request; + http_request.method = "GET"; + http_request.url = GURL("http://www.google.com/"); + + MockWrite http_writes[] = { + MockWrite("GET / HTTP/1.1\r\n" + "Host: www.google.com\r\n" + "Connection: keep-alive\r\n\r\n"), + }; + MockRead http_reads[] = { + MockRead("HTTP/1.1 200 OK\r\n"), + MockRead("Content-Length: 7\r\n\r\n"), + MockRead("falafel"), + MockRead(SYNCHRONOUS, OK), + }; + StaticSocketDataProvider http_data(http_reads, arraysize(http_reads), + http_writes, arraysize(http_writes)); + session_deps_.socket_factory->AddSocketDataProvider(&http_data); + + scoped_refptr<HttpNetworkSession> session(CreateSession(&session_deps_)); + + // Preconnect an SSL socket. A preconnect is needed because connect jobs are + // cancelled when a normal transaction is cancelled. + net::HttpStreamFactory* http_stream_factory = session->http_stream_factory(); + net::SSLConfig ssl_config; + session->ssl_config_service()->GetSSLConfig(&ssl_config); + http_stream_factory->PreconnectStreams(1, ssl_request, DEFAULT_PRIORITY, + ssl_config, ssl_config); + EXPECT_EQ(0, GetIdleSocketCountInSSLSocketPool(session)); + + // Start the HTTP request. Pool should stall. + TestCompletionCallback http_callback; + scoped_ptr<HttpTransaction> http_trans( + new HttpNetworkTransaction(DEFAULT_PRIORITY, session.get())); + ASSERT_EQ(ERR_IO_PENDING, + http_trans->Start(&http_request, http_callback.callback(), + BoundNetLog())); + EXPECT_TRUE(IsTransportSocketPoolStalled(session)); + + // The SSL connection will automatically be closed once the connection is + // established, to let the HTTP request start. + ASSERT_EQ(OK, http_callback.WaitForResult()); + std::string response_data; + ASSERT_EQ(OK, ReadTransaction(http_trans.get(), &response_data)); + EXPECT_EQ("falafel", response_data); + + EXPECT_EQ(1, GetIdleSocketCountInTransportSocketPool(session)); +} + } // namespace net diff --git a/chromium/net/http/http_pipelined_connection_impl_unittest.cc b/chromium/net/http/http_pipelined_connection_impl_unittest.cc index 1bf7597e33d..296194ecd63 100644 --- a/chromium/net/http/http_pipelined_connection_impl_unittest.cc +++ b/chromium/net/http/http_pipelined_connection_impl_unittest.cc @@ -28,15 +28,6 @@ using testing::StrEq; namespace net { -class DummySocketParams : public base::RefCounted<DummySocketParams> { - private: - friend class base::RefCounted<DummySocketParams>; - ~DummySocketParams() {} -}; - -REGISTER_SOCKET_PARAMS_FOR_POOL(MockTransportClientSocketPool, - DummySocketParams); - namespace { // Tests the load timing of a stream that's connected and is not the first @@ -118,7 +109,7 @@ class HttpPipelinedConnectionImplTest : public testing::Test { data_->StopAfter(reads_count + writes_count); } factory_.AddSocketDataProvider(data_.get()); - scoped_refptr<DummySocketParams> params; + scoped_refptr<MockTransportSocketParams> params; ClientSocketHandle* connection = new ClientSocketHandle; // Only give the connection a real NetLog to make sure that LoadTiming uses // the connection's ID, rather than the pipeline's. Since pipelines are diff --git a/chromium/net/http/http_pipelined_host_forced.cc b/chromium/net/http/http_pipelined_host_forced.cc index 8179e86f319..8059d848d73 100644 --- a/chromium/net/http/http_pipelined_host_forced.cc +++ b/chromium/net/http/http_pipelined_host_forced.cc @@ -36,10 +36,9 @@ HttpPipelinedStream* HttpPipelinedHostForced::CreateStreamOnNewPipeline( bool was_npn_negotiated, NextProto protocol_negotiated) { CHECK(!pipeline_.get()); - StreamSocket* wrapped_socket = connection->release_socket(); - BufferedWriteStreamSocket* buffered_socket = new BufferedWriteStreamSocket( - wrapped_socket); - connection->set_socket(buffered_socket); + scoped_ptr<BufferedWriteStreamSocket> buffered_socket( + new BufferedWriteStreamSocket(connection->PassSocket())); + connection->SetSocket(buffered_socket.PassAs<StreamSocket>()); pipeline_.reset(factory_->CreateNewPipeline( connection, this, key_.origin(), used_ssl_config, used_proxy_info, net_log, was_npn_negotiated, protocol_negotiated)); diff --git a/chromium/net/http/http_pipelined_stream.cc b/chromium/net/http/http_pipelined_stream.cc index 951c2f9afd6..df5743556d1 100644 --- a/chromium/net/http/http_pipelined_stream.cc +++ b/chromium/net/http/http_pipelined_stream.cc @@ -121,6 +121,11 @@ void HttpPipelinedStream::Drain(HttpNetworkSession* session) { pipeline_->Drain(this, session); } +void HttpPipelinedStream::SetPriority(RequestPriority priority) { + // TODO(akalin): Plumb this through to |pipeline_| and its + // underlying ClientSocketHandle. +} + const SSLConfig& HttpPipelinedStream::used_ssl_config() const { return pipeline_->used_ssl_config(); } diff --git a/chromium/net/http/http_pipelined_stream.h b/chromium/net/http/http_pipelined_stream.h index 675d8f083a2..d3a7991e5ca 100644 --- a/chromium/net/http/http_pipelined_stream.h +++ b/chromium/net/http/http_pipelined_stream.h @@ -81,6 +81,8 @@ class HttpPipelinedStream : public HttpStream { virtual void Drain(HttpNetworkSession* session) OVERRIDE; + virtual void SetPriority(RequestPriority priority) OVERRIDE; + // The SSLConfig used to establish this stream's pipeline. const SSLConfig& used_ssl_config() const; diff --git a/chromium/net/http/http_proxy_client_socket_pool.cc b/chromium/net/http/http_proxy_client_socket_pool.cc index b80df37b3dd..8d691b7dd80 100644 --- a/chromium/net/http/http_proxy_client_socket_pool.cc +++ b/chromium/net/http/http_proxy_client_socket_pool.cc @@ -58,7 +58,7 @@ HttpProxySocketParams::HttpProxySocketParams( const HostResolver::RequestInfo& HttpProxySocketParams::destination() const { if (transport_params_.get() == NULL) { - return ssl_params_->transport_params()->destination(); + return ssl_params_->GetDirectConnectionParams()->destination(); } else { return transport_params_->destination(); } @@ -77,6 +77,7 @@ static const int kHttpProxyConnectJobTimeoutInSeconds = 30; HttpProxyConnectJob::HttpProxyConnectJob( const std::string& group_name, + RequestPriority priority, const scoped_refptr<HttpProxySocketParams>& params, const base::TimeDelta& timeout_duration, TransportClientSocketPool* transport_pool, @@ -84,7 +85,7 @@ HttpProxyConnectJob::HttpProxyConnectJob( HostResolver* host_resolver, Delegate* delegate, NetLog* net_log) - : ConnectJob(group_name, timeout_duration, delegate, + : ConnectJob(group_name, timeout_duration, priority, delegate, BoundNetLog::Make(net_log, NetLog::SOURCE_CONNECT_JOB)), weak_ptr_factory_(this), params_(params), @@ -179,10 +180,12 @@ int HttpProxyConnectJob::DoLoop(int result) { int HttpProxyConnectJob::DoTransportConnect() { next_state_ = STATE_TCP_CONNECT_COMPLETE; transport_socket_handle_.reset(new ClientSocketHandle()); - return transport_socket_handle_->Init( - group_name(), params_->transport_params(), - params_->transport_params()->destination().priority(), callback_, - transport_pool_, net_log()); + return transport_socket_handle_->Init(group_name(), + params_->transport_params(), + priority(), + callback_, + transport_pool_, + net_log()); } int HttpProxyConnectJob::DoTransportConnectComplete(int result) { @@ -213,9 +216,8 @@ int HttpProxyConnectJob::DoSSLConnect() { next_state_ = STATE_SSL_CONNECT_COMPLETE; transport_socket_handle_.reset(new ClientSocketHandle()); return transport_socket_handle_->Init( - group_name(), params_->ssl_params(), - params_->ssl_params()->transport_params()->destination().priority(), - callback_, ssl_pool_, net_log()); + group_name(), params_->ssl_params(), priority(), callback_, + ssl_pool_, net_log()); } int HttpProxyConnectJob::DoSSLConnectComplete(int result) { @@ -289,7 +291,7 @@ int HttpProxyConnectJob::DoHttpProxyConnect() { int HttpProxyConnectJob::DoHttpProxyConnectComplete(int result) { if (result == OK || result == ERR_PROXY_AUTH_REQUESTED || result == ERR_HTTPS_PROXY_TUNNEL_RESPONSE) { - set_socket(transport_socket_.release()); + SetSocket(transport_socket_.PassAs<StreamSocket>()); } return result; @@ -321,9 +323,12 @@ int HttpProxyConnectJob::DoSpdyProxyCreateStream() { } next_state_ = STATE_SPDY_PROXY_CREATE_STREAM_COMPLETE; - return spdy_stream_request_.StartRequest( - SPDY_BIDIRECTIONAL_STREAM, spdy_session, params_->request_url(), - params_->destination().priority(), spdy_session->net_log(), callback_); + return spdy_stream_request_.StartRequest(SPDY_BIDIRECTIONAL_STREAM, + spdy_session, + params_->request_url(), + priority(), + spdy_session->net_log(), + callback_); } int HttpProxyConnectJob::DoSpdyProxyCreateStreamComplete(int result) { @@ -380,19 +385,20 @@ HttpProxyConnectJobFactory::HttpProxyConnectJobFactory( } -ConnectJob* +scoped_ptr<ConnectJob> HttpProxyClientSocketPool::HttpProxyConnectJobFactory::NewConnectJob( const std::string& group_name, const PoolBase::Request& request, ConnectJob::Delegate* delegate) const { - return new HttpProxyConnectJob(group_name, - request.params(), - ConnectionTimeout(), - transport_pool_, - ssl_pool_, - host_resolver_, - delegate, - net_log_); + return scoped_ptr<ConnectJob>(new HttpProxyConnectJob(group_name, + request.priority(), + request.params(), + ConnectionTimeout(), + transport_pool_, + ssl_pool_, + host_resolver_, + delegate, + net_log_)); } base::TimeDelta @@ -411,7 +417,7 @@ HttpProxyClientSocketPool::HttpProxyClientSocketPool( NetLog* net_log) : transport_pool_(transport_pool), ssl_pool_(ssl_pool), - base_(max_sockets, max_sockets_per_group, histograms, + base_(this, max_sockets, max_sockets_per_group, histograms, ClientSocketPool::unused_idle_socket_timeout(), ClientSocketPool::used_idle_socket_timeout(), new HttpProxyConnectJobFactory(transport_pool, @@ -420,17 +426,12 @@ HttpProxyClientSocketPool::HttpProxyClientSocketPool( net_log)) { // We should always have a |transport_pool_| except in unit tests. if (transport_pool_) - transport_pool_->AddLayeredPool(this); + base_.AddLowerLayeredPool(transport_pool_); if (ssl_pool_) - ssl_pool_->AddLayeredPool(this); + base_.AddLowerLayeredPool(ssl_pool_); } HttpProxyClientSocketPool::~HttpProxyClientSocketPool() { - if (ssl_pool_) - ssl_pool_->RemoveLayeredPool(this); - // We should always have a |transport_pool_| except in unit tests. - if (transport_pool_) - transport_pool_->RemoveLayeredPool(this); } int HttpProxyClientSocketPool::RequestSocket( @@ -462,20 +463,15 @@ void HttpProxyClientSocketPool::CancelRequest( } void HttpProxyClientSocketPool::ReleaseSocket(const std::string& group_name, - StreamSocket* socket, int id) { - base_.ReleaseSocket(group_name, socket, id); + scoped_ptr<StreamSocket> socket, + int id) { + base_.ReleaseSocket(group_name, socket.Pass(), id); } void HttpProxyClientSocketPool::FlushWithError(int error) { base_.FlushWithError(error); } -bool HttpProxyClientSocketPool::IsStalled() const { - return base_.IsStalled() || - (transport_pool_ && transport_pool_->IsStalled()) || - (ssl_pool_ && ssl_pool_->IsStalled()); -} - void HttpProxyClientSocketPool::CloseIdleSockets() { base_.CloseIdleSockets(); } @@ -494,14 +490,6 @@ LoadState HttpProxyClientSocketPool::GetLoadState( return base_.GetLoadState(group_name, handle); } -void HttpProxyClientSocketPool::AddLayeredPool(LayeredPool* layered_pool) { - base_.AddLayeredPool(layered_pool); -} - -void HttpProxyClientSocketPool::RemoveLayeredPool(LayeredPool* layered_pool) { - base_.RemoveLayeredPool(layered_pool); -} - base::DictionaryValue* HttpProxyClientSocketPool::GetInfoAsValue( const std::string& name, const std::string& type, @@ -532,10 +520,24 @@ ClientSocketPoolHistograms* HttpProxyClientSocketPool::histograms() const { return base_.histograms(); } +bool HttpProxyClientSocketPool::IsStalled() const { + return base_.IsStalled(); +} + +void HttpProxyClientSocketPool::AddHigherLayeredPool( + HigherLayeredPool* higher_pool) { + base_.AddHigherLayeredPool(higher_pool); +} + +void HttpProxyClientSocketPool::RemoveHigherLayeredPool( + HigherLayeredPool* higher_pool) { + base_.RemoveHigherLayeredPool(higher_pool); +} + bool HttpProxyClientSocketPool::CloseOneIdleConnection() { if (base_.CloseOneIdleSocket()) return true; - return base_.CloseOneIdleConnectionInLayeredPool(); + return base_.CloseOneIdleConnectionInHigherLayeredPool(); } } // namespace net diff --git a/chromium/net/http/http_proxy_client_socket_pool.h b/chromium/net/http/http_proxy_client_socket_pool.h index a15b8cad809..a26c05f6603 100644 --- a/chromium/net/http/http_proxy_client_socket_pool.h +++ b/chromium/net/http/http_proxy_client_socket_pool.h @@ -96,6 +96,7 @@ class NET_EXPORT_PRIVATE HttpProxySocketParams class HttpProxyConnectJob : public ConnectJob { public: HttpProxyConnectJob(const std::string& group_name, + RequestPriority priority, const scoped_refptr<HttpProxySocketParams>& params, const base::TimeDelta& timeout_duration, TransportClientSocketPool* transport_pool, @@ -174,8 +175,10 @@ class HttpProxyConnectJob : public ConnectJob { class NET_EXPORT_PRIVATE HttpProxyClientSocketPool : public ClientSocketPool, - public LayeredPool { + public HigherLayeredPool { public: + typedef HttpProxySocketParams SocketParams; + HttpProxyClientSocketPool( int max_sockets, int max_sockets_per_group, @@ -204,13 +207,11 @@ class NET_EXPORT_PRIVATE HttpProxyClientSocketPool ClientSocketHandle* handle) OVERRIDE; virtual void ReleaseSocket(const std::string& group_name, - StreamSocket* socket, + scoped_ptr<StreamSocket> socket, int id) OVERRIDE; virtual void FlushWithError(int error) OVERRIDE; - virtual bool IsStalled() const OVERRIDE; - virtual void CloseIdleSockets() OVERRIDE; virtual int IdleSocketCount() const OVERRIDE; @@ -222,10 +223,6 @@ class NET_EXPORT_PRIVATE HttpProxyClientSocketPool const std::string& group_name, const ClientSocketHandle* handle) const OVERRIDE; - virtual void AddLayeredPool(LayeredPool* layered_pool) OVERRIDE; - - virtual void RemoveLayeredPool(LayeredPool* layered_pool) OVERRIDE; - virtual base::DictionaryValue* GetInfoAsValue( const std::string& name, const std::string& type, @@ -235,7 +232,14 @@ class NET_EXPORT_PRIVATE HttpProxyClientSocketPool virtual ClientSocketPoolHistograms* histograms() const OVERRIDE; - // LayeredPool implementation. + // LowerLayeredPool implementation. + virtual bool IsStalled() const OVERRIDE; + + virtual void AddHigherLayeredPool(HigherLayeredPool* higher_pool) OVERRIDE; + + virtual void RemoveHigherLayeredPool(HigherLayeredPool* higher_pool) OVERRIDE; + + // HigherLayeredPool implementation. virtual bool CloseOneIdleConnection() OVERRIDE; private: @@ -250,7 +254,7 @@ class NET_EXPORT_PRIVATE HttpProxyClientSocketPool NetLog* net_log); // ClientSocketPoolBase::ConnectJobFactory methods. - virtual ConnectJob* NewConnectJob( + virtual scoped_ptr<ConnectJob> NewConnectJob( const std::string& group_name, const PoolBase::Request& request, ConnectJob::Delegate* delegate) const OVERRIDE; @@ -274,9 +278,6 @@ class NET_EXPORT_PRIVATE HttpProxyClientSocketPool DISALLOW_COPY_AND_ASSIGN(HttpProxyClientSocketPool); }; -REGISTER_SOCKET_PARAMS_FOR_POOL(HttpProxyClientSocketPool, - HttpProxySocketParams); - } // namespace net #endif // NET_HTTP_HTTP_PROXY_CLIENT_SOCKET_POOL_H_ diff --git a/chromium/net/http/http_proxy_client_socket_pool_unittest.cc b/chromium/net/http/http_proxy_client_socket_pool_unittest.cc index 2274f250ff1..808305240ad 100644 --- a/chromium/net/http/http_proxy_client_socket_pool_unittest.cc +++ b/chromium/net/http/http_proxy_client_socket_pool_unittest.cc @@ -55,31 +55,14 @@ struct HttpProxyClientSocketPoolTestParams { typedef ::testing::TestWithParam<HttpProxyType> TestWithHttpParam; -} // namespace +const char kHttpProxyHost[] = "httpproxy.example.com"; +const char kHttpsProxyHost[] = "httpsproxy.example.com"; class HttpProxyClientSocketPoolTest : public ::testing::TestWithParam<HttpProxyClientSocketPoolTestParams> { protected: HttpProxyClientSocketPoolTest() : session_deps_(GetParam().protocol), - ssl_config_(), - ignored_transport_socket_params_( - new TransportSocketParams(HostPortPair("proxy", 80), - LOWEST, - false, - false, - OnHostResolutionCallback())), - ignored_ssl_socket_params_( - new SSLSocketParams(ignored_transport_socket_params_, - NULL, - NULL, - ProxyServer::SCHEME_DIRECT, - HostPortPair("www.google.com", 443), - ssl_config_, - kPrivacyModeDisabled, - 0, - false, - false)), tcp_histograms_("MockTCP"), transport_socket_pool_( kMaxSockets, @@ -118,7 +101,9 @@ class HttpProxyClientSocketPoolTest void AddAuthToCache() { const base::string16 kFoo(ASCIIToUTF16("foo")); const base::string16 kBar(ASCIIToUTF16("bar")); - GURL proxy_url(GetParam().proxy_type == HTTP ? "http://proxy" : "https://proxy:80"); + GURL proxy_url(GetParam().proxy_type == HTTP ? + (std::string("http://") + kHttpProxyHost) : + (std::string("https://") + kHttpsProxyHost)); session_->http_auth_cache()->Add(proxy_url, "MyRealm1", HttpAuth::AUTH_SCHEME_BASIC, @@ -127,24 +112,40 @@ class HttpProxyClientSocketPoolTest "/"); } - scoped_refptr<TransportSocketParams> GetTcpParams() { + scoped_refptr<TransportSocketParams> CreateHttpProxyParams() const { if (GetParam().proxy_type != HTTP) - return scoped_refptr<TransportSocketParams>(); - return ignored_transport_socket_params_; + return NULL; + return new TransportSocketParams(HostPortPair(kHttpProxyHost, 80), + false, + false, + OnHostResolutionCallback()); } - scoped_refptr<SSLSocketParams> GetSslParams() { + scoped_refptr<SSLSocketParams> CreateHttpsProxyParams() const { if (GetParam().proxy_type == HTTP) - return scoped_refptr<SSLSocketParams>(); - return ignored_ssl_socket_params_; + return NULL; + return new SSLSocketParams( + new TransportSocketParams( + HostPortPair(kHttpsProxyHost, 443), + false, + false, + OnHostResolutionCallback()), + NULL, + NULL, + HostPortPair(kHttpsProxyHost, 443), + SSLConfig(), + kPrivacyModeDisabled, + 0, + false, + false); } // Returns the a correctly constructed HttpProxyParms // for the HTTP or HTTPS proxy. - scoped_refptr<HttpProxySocketParams> GetParams(bool tunnel) { + scoped_refptr<HttpProxySocketParams> CreateParams(bool tunnel) { return scoped_refptr<HttpProxySocketParams>(new HttpProxySocketParams( - GetTcpParams(), - GetSslParams(), + CreateHttpProxyParams(), + CreateHttpsProxyParams(), GURL(tunnel ? "https://www.google.com/" : "http://www.google.com"), std::string(), HostPortPair("www.google.com", tunnel ? 443 : 80), @@ -154,16 +155,16 @@ class HttpProxyClientSocketPoolTest tunnel)); } - scoped_refptr<HttpProxySocketParams> GetTunnelParams() { - return GetParams(true); + scoped_refptr<HttpProxySocketParams> CreateTunnelParams() { + return CreateParams(true); } - scoped_refptr<HttpProxySocketParams> GetNoTunnelParams() { - return GetParams(false); + scoped_refptr<HttpProxySocketParams> CreateNoTunnelParams() { + return CreateParams(false); } - DeterministicMockClientSocketFactory& socket_factory() { - return *session_deps_.deterministic_socket_factory.get(); + DeterministicMockClientSocketFactory* socket_factory() { + return session_deps_.deterministic_socket_factory.get(); } void Initialize(MockRead* reads, size_t reads_count, @@ -181,14 +182,14 @@ class HttpProxyClientSocketPoolTest data_->set_connect_data(MockConnect(SYNCHRONOUS, OK)); data_->StopAfter(2); // Request / Response - socket_factory().AddSocketDataProvider(data_.get()); + socket_factory()->AddSocketDataProvider(data_.get()); if (GetParam().proxy_type != HTTP) { ssl_data_.reset(new SSLSocketDataProvider(SYNCHRONOUS, OK)); if (GetParam().proxy_type == SPDY) { InitializeSpdySsl(); } - socket_factory().AddSSLSocketDataProvider(ssl_data_.get()); + socket_factory()->AddSSLSocketDataProvider(ssl_data_.get()); } } @@ -201,12 +202,13 @@ class HttpProxyClientSocketPoolTest &session_deps_); } + RequestPriority GetLastTransportRequestPriority() const { + return transport_socket_pool_.last_request_priority(); + } + private: SpdySessionDependencies session_deps_; - SSLConfig ssl_config_; - scoped_refptr<TransportSocketParams> ignored_transport_socket_params_; - scoped_refptr<SSLSocketParams> ignored_ssl_socket_params_; ClientSocketPoolHistograms tcp_histograms_; MockTransportClientSocketPool transport_socket_pool_; ClientSocketPoolHistograms ssl_histograms_; @@ -255,7 +257,7 @@ INSTANTIATE_TEST_CASE_P( TEST_P(HttpProxyClientSocketPoolTest, NoTunnel) { Initialize(NULL, 0, NULL, 0, NULL, 0, NULL, 0); - int rv = handle_.Init("a", GetNoTunnelParams(), LOW, CompletionCallback(), + int rv = handle_.Init("a", CreateNoTunnelParams(), LOW, CompletionCallback(), &pool_, BoundNetLog()); EXPECT_EQ(OK, rv); EXPECT_TRUE(handle_.is_initialized()); @@ -265,6 +267,16 @@ TEST_P(HttpProxyClientSocketPoolTest, NoTunnel) { EXPECT_TRUE(tunnel_socket->IsConnected()); } +// Make sure that HttpProxyConnectJob passes on its priority to its +// (non-SSL) socket request on Init. +TEST_P(HttpProxyClientSocketPoolTest, SetSocketRequestPriorityOnInit) { + Initialize(NULL, 0, NULL, 0, NULL, 0, NULL, 0); + EXPECT_EQ(OK, + handle_.Init("a", CreateNoTunnelParams(), HIGHEST, + CompletionCallback(), &pool_, BoundNetLog())); + EXPECT_EQ(HIGHEST, GetLastTransportRequestPriority()); +} + TEST_P(HttpProxyClientSocketPoolTest, NeedAuth) { MockWrite writes[] = { MockWrite(ASYNC, 0, "CONNECT www.google.com:443 HTTP/1.1\r\n" @@ -279,7 +291,7 @@ TEST_P(HttpProxyClientSocketPoolTest, NeedAuth) { MockRead(ASYNC, 4, "0123456789"), }; scoped_ptr<SpdyFrame> req( - spdy_util_.ConstructSpdyConnect(NULL, 0, 1)); + spdy_util_.ConstructSpdyConnect(NULL, 0, 1, LOW)); scoped_ptr<SpdyFrame> rst( spdy_util_.ConstructSpdyRstStream(1, RST_STREAM_CANCEL)); MockWrite spdy_writes[] = { @@ -296,7 +308,7 @@ TEST_P(HttpProxyClientSocketPoolTest, NeedAuth) { 0, false, 1, - LOWEST, + LOW, SYN_REPLY, CONTROL_FLAG_NONE, kAuthChallenge, @@ -312,7 +324,7 @@ TEST_P(HttpProxyClientSocketPoolTest, NeedAuth) { arraysize(spdy_writes)); data_->StopAfter(4); - int rv = handle_.Init("a", GetTunnelParams(), LOW, callback_.callback(), + int rv = handle_.Init("a", CreateTunnelParams(), LOW, callback_.callback(), &pool_, BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle_.is_initialized()); @@ -331,7 +343,6 @@ TEST_P(HttpProxyClientSocketPoolTest, NeedAuth) { } else { EXPECT_FALSE(tunnel_socket->IsConnected()); EXPECT_FALSE(tunnel_socket->IsUsingSpdy()); - EXPECT_FALSE(tunnel_socket->IsUsingSpdy()); } } @@ -355,7 +366,7 @@ TEST_P(HttpProxyClientSocketPoolTest, HaveAuth) { NULL, 0); AddAuthToCache(); - int rv = handle_.Init("a", GetTunnelParams(), LOW, callback_.callback(), + int rv = handle_.Init("a", CreateTunnelParams(), LOW, callback_.callback(), &pool_, BoundNetLog()); EXPECT_EQ(OK, rv); EXPECT_TRUE(handle_.is_initialized()); @@ -377,7 +388,7 @@ TEST_P(HttpProxyClientSocketPoolTest, AsyncHaveAuth) { }; scoped_ptr<SpdyFrame> req( - spdy_util_.ConstructSpdyConnect(kAuthHeaders, kAuthHeadersSize, 1)); + spdy_util_.ConstructSpdyConnect(kAuthHeaders, kAuthHeadersSize, 1, LOW)); MockWrite spdy_writes[] = { CreateMockWrite(*req, 0, ASYNC) }; @@ -392,7 +403,7 @@ TEST_P(HttpProxyClientSocketPoolTest, AsyncHaveAuth) { arraysize(spdy_writes)); AddAuthToCache(); - int rv = handle_.Init("a", GetTunnelParams(), LOW, callback_.callback(), + int rv = handle_.Init("a", CreateTunnelParams(), LOW, callback_.callback(), &pool_, BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle_.is_initialized()); @@ -407,14 +418,47 @@ TEST_P(HttpProxyClientSocketPoolTest, AsyncHaveAuth) { EXPECT_TRUE(tunnel_socket->IsConnected()); } +// Make sure that HttpProxyConnectJob passes on its priority to its +// SPDY session's socket request on Init (if applicable). +TEST_P(HttpProxyClientSocketPoolTest, + SetSpdySessionSocketRequestPriorityOnInit) { + if (GetParam().proxy_type != SPDY) + return; + + scoped_ptr<SpdyFrame> req( + spdy_util_.ConstructSpdyConnect(kAuthHeaders, kAuthHeadersSize, + 1, MEDIUM)); + MockWrite spdy_writes[] = { + CreateMockWrite(*req, 0, ASYNC) + }; + scoped_ptr<SpdyFrame> resp(spdy_util_.ConstructSpdyGetSynReply(NULL, 0, 1)); + MockRead spdy_reads[] = { + CreateMockRead(*resp, 1, ASYNC), + MockRead(ASYNC, 0, 2) + }; + + Initialize(NULL, 0, NULL, 0, + spdy_reads, arraysize(spdy_reads), + spdy_writes, arraysize(spdy_writes)); + AddAuthToCache(); + + EXPECT_EQ(ERR_IO_PENDING, + handle_.Init("a", CreateTunnelParams(), MEDIUM, + callback_.callback(), &pool_, BoundNetLog())); + EXPECT_EQ(MEDIUM, GetLastTransportRequestPriority()); + + data_->RunFor(2); + EXPECT_EQ(OK, callback_.WaitForResult()); +} + TEST_P(HttpProxyClientSocketPoolTest, TCPError) { if (GetParam().proxy_type == SPDY) return; data_.reset(new DeterministicSocketData(NULL, 0, NULL, 0)); data_->set_connect_data(MockConnect(ASYNC, ERR_CONNECTION_CLOSED)); - socket_factory().AddSocketDataProvider(data_.get()); + socket_factory()->AddSocketDataProvider(data_.get()); - int rv = handle_.Init("a", GetTunnelParams(), LOW, callback_.callback(), + int rv = handle_.Init("a", CreateTunnelParams(), LOW, callback_.callback(), &pool_, BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle_.is_initialized()); @@ -430,16 +474,16 @@ TEST_P(HttpProxyClientSocketPoolTest, SSLError) { if (GetParam().proxy_type == HTTP) return; data_.reset(new DeterministicSocketData(NULL, 0, NULL, 0)); data_->set_connect_data(MockConnect(ASYNC, OK)); - socket_factory().AddSocketDataProvider(data_.get()); + socket_factory()->AddSocketDataProvider(data_.get()); ssl_data_.reset(new SSLSocketDataProvider(ASYNC, ERR_CERT_AUTHORITY_INVALID)); if (GetParam().proxy_type == SPDY) { InitializeSpdySsl(); } - socket_factory().AddSSLSocketDataProvider(ssl_data_.get()); + socket_factory()->AddSSLSocketDataProvider(ssl_data_.get()); - int rv = handle_.Init("a", GetTunnelParams(), LOW, callback_.callback(), + int rv = handle_.Init("a", CreateTunnelParams(), LOW, callback_.callback(), &pool_, BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle_.is_initialized()); @@ -455,16 +499,16 @@ TEST_P(HttpProxyClientSocketPoolTest, SslClientAuth) { if (GetParam().proxy_type == HTTP) return; data_.reset(new DeterministicSocketData(NULL, 0, NULL, 0)); data_->set_connect_data(MockConnect(ASYNC, OK)); - socket_factory().AddSocketDataProvider(data_.get()); + socket_factory()->AddSocketDataProvider(data_.get()); ssl_data_.reset(new SSLSocketDataProvider(ASYNC, ERR_SSL_CLIENT_AUTH_CERT_NEEDED)); if (GetParam().proxy_type == SPDY) { InitializeSpdySsl(); } - socket_factory().AddSSLSocketDataProvider(ssl_data_.get()); + socket_factory()->AddSSLSocketDataProvider(ssl_data_.get()); - int rv = handle_.Init("a", GetTunnelParams(), LOW, callback_.callback(), + int rv = handle_.Init("a", CreateTunnelParams(), LOW, callback_.callback(), &pool_, BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle_.is_initialized()); @@ -489,7 +533,7 @@ TEST_P(HttpProxyClientSocketPoolTest, TunnelUnexpectedClose) { MockRead(ASYNC, ERR_CONNECTION_CLOSED, 2), }; scoped_ptr<SpdyFrame> req( - spdy_util_.ConstructSpdyConnect(kAuthHeaders, kAuthHeadersSize, 1)); + spdy_util_.ConstructSpdyConnect(kAuthHeaders, kAuthHeadersSize, 1, LOW)); MockWrite spdy_writes[] = { CreateMockWrite(*req, 0, ASYNC) }; @@ -502,7 +546,7 @@ TEST_P(HttpProxyClientSocketPoolTest, TunnelUnexpectedClose) { arraysize(spdy_writes)); AddAuthToCache(); - int rv = handle_.Init("a", GetTunnelParams(), LOW, callback_.callback(), + int rv = handle_.Init("a", CreateTunnelParams(), LOW, callback_.callback(), &pool_, BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle_.is_initialized()); @@ -520,6 +564,37 @@ TEST_P(HttpProxyClientSocketPoolTest, TunnelUnexpectedClose) { EXPECT_FALSE(handle_.socket()); } +TEST_P(HttpProxyClientSocketPoolTest, Tunnel1xxResponse) { + // Tests that 1xx responses are rejected for a CONNECT request. + if (GetParam().proxy_type == SPDY) { + // SPDY doesn't have 1xx responses. + return; + } + + MockWrite writes[] = { + MockWrite(ASYNC, 0, + "CONNECT www.google.com:443 HTTP/1.1\r\n" + "Host: www.google.com\r\n" + "Proxy-Connection: keep-alive\r\n\r\n"), + }; + MockRead reads[] = { + MockRead(ASYNC, 1, "HTTP/1.1 100 Continue\r\n\r\n"), + MockRead(ASYNC, 2, "HTTP/1.1 200 Connection Established\r\n\r\n"), + }; + + Initialize(reads, arraysize(reads), writes, arraysize(writes), + NULL, 0, NULL, 0); + + int rv = handle_.Init("a", CreateTunnelParams(), LOW, callback_.callback(), + &pool_, BoundNetLog()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_FALSE(handle_.is_initialized()); + EXPECT_FALSE(handle_.socket()); + + data_->RunFor(2); + EXPECT_EQ(ERR_TUNNEL_CONNECTION_FAILED, callback_.WaitForResult()); +} + TEST_P(HttpProxyClientSocketPoolTest, TunnelSetupError) { MockWrite writes[] = { MockWrite(ASYNC, 0, @@ -532,7 +607,7 @@ TEST_P(HttpProxyClientSocketPoolTest, TunnelSetupError) { MockRead(ASYNC, 1, "HTTP/1.1 304 Not Modified\r\n\r\n"), }; scoped_ptr<SpdyFrame> req( - spdy_util_.ConstructSpdyConnect(kAuthHeaders, kAuthHeadersSize, 1)); + spdy_util_.ConstructSpdyConnect(kAuthHeaders, kAuthHeadersSize, 1, LOW)); scoped_ptr<SpdyFrame> rst( spdy_util_.ConstructSpdyRstStream(1, RST_STREAM_CANCEL)); MockWrite spdy_writes[] = { @@ -550,7 +625,7 @@ TEST_P(HttpProxyClientSocketPoolTest, TunnelSetupError) { arraysize(spdy_writes)); AddAuthToCache(); - int rv = handle_.Init("a", GetTunnelParams(), LOW, callback_.callback(), + int rv = handle_.Init("a", CreateTunnelParams(), LOW, callback_.callback(), &pool_, BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle_.is_initialized()); @@ -583,7 +658,7 @@ TEST_P(HttpProxyClientSocketPoolTest, TunnelSetupRedirect) { MockRead(ASYNC, 1, responseText.c_str()), }; scoped_ptr<SpdyFrame> req( - spdy_util_.ConstructSpdyConnect(kAuthHeaders, kAuthHeadersSize, 1)); + spdy_util_.ConstructSpdyConnect(kAuthHeaders, kAuthHeadersSize, 1, LOW)); scoped_ptr<SpdyFrame> rst( spdy_util_.ConstructSpdyRstStream(1, RST_STREAM_CANCEL)); @@ -612,7 +687,7 @@ TEST_P(HttpProxyClientSocketPoolTest, TunnelSetupRedirect) { arraysize(spdy_writes)); AddAuthToCache(); - int rv = handle_.Init("a", GetTunnelParams(), LOW, callback_.callback(), + int rv = handle_.Init("a", CreateTunnelParams(), LOW, callback_.callback(), &pool_, BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle_.is_initialized()); @@ -653,4 +728,6 @@ TEST_P(HttpProxyClientSocketPoolTest, TunnelSetupRedirect) { // It would be nice to also test the timeouts in HttpProxyClientSocketPool. +} // namespace + } // namespace net diff --git a/chromium/net/http/http_response_body_drainer_unittest.cc b/chromium/net/http/http_response_body_drainer_unittest.cc index 5d9fcc4689e..70134cce1ea 100644 --- a/chromium/net/http/http_response_body_drainer_unittest.cc +++ b/chromium/net/http/http_response_body_drainer_unittest.cc @@ -127,6 +127,8 @@ class MockHttpStream : public HttpStream { virtual void Drain(HttpNetworkSession*) OVERRIDE {} + virtual void SetPriority(RequestPriority priority) OVERRIDE {} + // Methods to tweak/observer mock behavior: void set_stall_reads_forever() { stall_reads_forever_ = true; } diff --git a/chromium/net/http/http_stream_base.h b/chromium/net/http/http_stream_base.h index 6dce6d2a805..596ed75dff1 100644 --- a/chromium/net/http/http_stream_base.h +++ b/chromium/net/http/http_stream_base.h @@ -141,6 +141,9 @@ class NET_EXPORT_PRIVATE HttpStreamBase { // draining is complete. virtual void Drain(HttpNetworkSession* session) = 0; + // Called when the priority of the parent transaction changes. + virtual void SetPriority(RequestPriority priority) = 0; + private: DISALLOW_COPY_AND_ASSIGN(HttpStreamBase); }; diff --git a/chromium/net/http/http_stream_factory.h b/chromium/net/http/http_stream_factory.h index 6db6905b433..0de3b65bc57 100644 --- a/chromium/net/http/http_stream_factory.h +++ b/chromium/net/http/http_stream_factory.h @@ -157,6 +157,9 @@ class NET_EXPORT_PRIVATE HttpStreamRequest { virtual int RestartTunnelWithProxyAuth( const AuthCredentials& credentials) = 0; + // Called when the priority of the parent transaction changes. + virtual void SetPriority(RequestPriority priority) = 0; + // Returns the LoadState for the request. virtual LoadState GetLoadState() const = 0; diff --git a/chromium/net/http/http_stream_factory_impl.h b/chromium/net/http/http_stream_factory_impl.h index 3949f3839ee..4339fd350d7 100644 --- a/chromium/net/http/http_stream_factory_impl.h +++ b/chromium/net/http/http_stream_factory_impl.h @@ -9,6 +9,7 @@ #include <set> #include <vector> +#include "base/gtest_prod_util.h" #include "base/memory/ref_counted.h" #include "net/base/host_port_pair.h" #include "net/base/net_log.h" @@ -66,8 +67,10 @@ class NET_EXPORT_PRIVATE HttpStreamFactoryImpl : size_t num_orphaned_jobs() const { return orphaned_job_set_.size(); } private: - class Request; - class Job; + FRIEND_TEST_ALL_PREFIXES(HttpStreamFactoryImplRequestTest, SetPriority); + + class NET_EXPORT_PRIVATE Request; + class NET_EXPORT_PRIVATE Job; typedef std::set<Request*> RequestSet; typedef std::vector<Request*> RequestVector; diff --git a/chromium/net/http/http_stream_factory_impl_job.cc b/chromium/net/http/http_stream_factory_impl_job.cc index b2eee3b0fbe..c0383f4772d 100644 --- a/chromium/net/http/http_stream_factory_impl_job.cc +++ b/chromium/net/http/http_stream_factory_impl_job.cc @@ -230,12 +230,17 @@ void HttpStreamFactoryImpl::Job::Orphan(const Request* request) { } } +void HttpStreamFactoryImpl::Job::SetPriority(RequestPriority priority) { + priority_ = priority; + // TODO(akalin): Propagate this to |connection_| and maybe the + // preconnect state. +} + bool HttpStreamFactoryImpl::Job::was_npn_negotiated() const { return was_npn_negotiated_; } -NextProto HttpStreamFactoryImpl::Job::protocol_negotiated() - const { +NextProto HttpStreamFactoryImpl::Job::protocol_negotiated() const { return protocol_negotiated_; } diff --git a/chromium/net/http/http_stream_factory_impl_job.h b/chromium/net/http/http_stream_factory_impl_job.h index 2c2eb349586..01a794a1bc7 100644 --- a/chromium/net/http/http_stream_factory_impl_job.h +++ b/chromium/net/http/http_stream_factory_impl_job.h @@ -74,6 +74,9 @@ class HttpStreamFactoryImpl::Job { // Used to detach the Job from |request|. void Orphan(const Request* request); + void SetPriority(RequestPriority priority); + + RequestPriority priority() const { return priority_; } bool was_npn_negotiated() const; NextProto protocol_negotiated() const; bool using_spdy() const; diff --git a/chromium/net/http/http_stream_factory_impl_request.cc b/chromium/net/http/http_stream_factory_impl_request.cc index e73a897a528..57190ed72e7 100644 --- a/chromium/net/http/http_stream_factory_impl_request.cc +++ b/chromium/net/http/http_stream_factory_impl_request.cc @@ -215,6 +215,15 @@ int HttpStreamFactoryImpl::Request::RestartTunnelWithProxyAuth( return bound_job_->RestartTunnelWithProxyAuth(credentials); } +void HttpStreamFactoryImpl::Request::SetPriority(RequestPriority priority) { + for (std::set<HttpStreamFactoryImpl::Job*>::const_iterator it = jobs_.begin(); + it != jobs_.end(); ++it) { + (*it)->SetPriority(priority); + } + if (bound_job_) + bound_job_->SetPriority(priority); +} + LoadState HttpStreamFactoryImpl::Request::GetLoadState() const { if (bound_job_.get()) return bound_job_->GetLoadState(); diff --git a/chromium/net/http/http_stream_factory_impl_request.h b/chromium/net/http/http_stream_factory_impl_request.h index 169e1f54ce9..d6f9b02cbef 100644 --- a/chromium/net/http/http_stream_factory_impl_request.h +++ b/chromium/net/http/http_stream_factory_impl_request.h @@ -105,6 +105,7 @@ class HttpStreamFactoryImpl::Request : public HttpStreamRequest { virtual int RestartTunnelWithProxyAuth( const AuthCredentials& credentials) OVERRIDE; + virtual void SetPriority(RequestPriority priority) OVERRIDE; virtual LoadState GetLoadState() const OVERRIDE; virtual bool was_npn_negotiated() const OVERRIDE; virtual NextProto protocol_negotiated() const OVERRIDE; diff --git a/chromium/net/http/http_stream_factory_impl_request_unittest.cc b/chromium/net/http/http_stream_factory_impl_request_unittest.cc new file mode 100644 index 00000000000..1f38a2e56f6 --- /dev/null +++ b/chromium/net/http/http_stream_factory_impl_request_unittest.cc @@ -0,0 +1,98 @@ +// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/http/http_stream_factory_impl_request.h" + +#include "net/http/http_stream_factory_impl_job.h" +#include "net/proxy/proxy_info.h" +#include "net/proxy/proxy_service.h" +#include "net/spdy/spdy_test_util_common.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace net { + +class HttpStreamFactoryImplRequestTest + : public ::testing::Test, + public ::testing::WithParamInterface<NextProto> {}; + +INSTANTIATE_TEST_CASE_P( + NextProto, + HttpStreamFactoryImplRequestTest, + testing::Values(kProtoSPDY2, kProtoSPDY3, kProtoSPDY31, kProtoSPDY4a2, + kProtoHTTP2Draft04)); + +namespace { + +class DoNothingRequestDelegate : public HttpStreamRequest::Delegate { + public: + DoNothingRequestDelegate() {} + + virtual ~DoNothingRequestDelegate() {} + + // HttpStreamRequest::Delegate + virtual void OnStreamReady( + const SSLConfig& used_ssl_config, + const ProxyInfo& used_proxy_info, + HttpStreamBase* stream) OVERRIDE {} + virtual void OnWebSocketStreamReady( + const SSLConfig& used_ssl_config, + const ProxyInfo& used_proxy_info, + WebSocketStreamBase* stream) OVERRIDE {} + virtual void OnStreamFailed( + int status, + const SSLConfig& used_ssl_config) OVERRIDE {} + virtual void OnCertificateError( + int status, + const SSLConfig& used_ssl_config, + const SSLInfo& ssl_info) OVERRIDE {} + virtual void OnNeedsProxyAuth(const HttpResponseInfo& proxy_response, + const SSLConfig& used_ssl_config, + const ProxyInfo& used_proxy_info, + HttpAuthController* auth_controller) OVERRIDE {} + virtual void OnNeedsClientAuth(const SSLConfig& used_ssl_config, + SSLCertRequestInfo* cert_info) OVERRIDE {} + virtual void OnHttpsProxyTunnelResponse(const HttpResponseInfo& response_info, + const SSLConfig& used_ssl_config, + const ProxyInfo& used_proxy_info, + HttpStreamBase* stream) OVERRIDE {} +}; + +} // namespace + +// Make sure that Request passes on its priority updates to its jobs. +TEST_P(HttpStreamFactoryImplRequestTest, SetPriority) { + SpdySessionDependencies session_deps(GetParam(), + ProxyService::CreateDirect()); + + scoped_refptr<HttpNetworkSession> + session(SpdySessionDependencies::SpdyCreateSession(&session_deps)); + HttpStreamFactoryImpl* factory = + static_cast<HttpStreamFactoryImpl*>(session->http_stream_factory()); + + DoNothingRequestDelegate request_delegate; + HttpStreamFactoryImpl::Request request( + GURL(), factory, &request_delegate, NULL, BoundNetLog()); + + HttpStreamFactoryImpl::Job* job = + new HttpStreamFactoryImpl::Job(factory, + session, + HttpRequestInfo(), + DEFAULT_PRIORITY, + SSLConfig(), + SSLConfig(), + NULL); + request.AttachJob(job); + EXPECT_EQ(DEFAULT_PRIORITY, job->priority()); + + request.SetPriority(MEDIUM); + EXPECT_EQ(MEDIUM, job->priority()); + + // Make |job| the bound job. + request.OnStreamFailed(job, ERR_FAILED, SSLConfig()); + + request.SetPriority(IDLE); + EXPECT_EQ(IDLE, job->priority()); +} + +} // namespace net diff --git a/chromium/net/http/http_stream_factory_impl_unittest.cc b/chromium/net/http/http_stream_factory_impl_unittest.cc index 14fbc0338a3..f378c93ea18 100644 --- a/chromium/net/http/http_stream_factory_impl_unittest.cc +++ b/chromium/net/http/http_stream_factory_impl_unittest.cc @@ -314,7 +314,7 @@ class CapturePreconnectsSocketPool : public ParentPool { ADD_FAILURE(); } virtual void ReleaseSocket(const std::string& group_name, - StreamSocket* socket, + scoped_ptr<StreamSocket> socket, int id) OVERRIDE { ADD_FAILURE(); } diff --git a/chromium/net/http/http_stream_parser.cc b/chromium/net/http/http_stream_parser.cc index e53aafdac1c..853414a106a 100644 --- a/chromium/net/http/http_stream_parser.cc +++ b/chromium/net/http/http_stream_parser.cc @@ -79,7 +79,7 @@ namespace net { // Example: // // scoped_refptr<SeekableIOBuffer> buf = new SeekableIOBuffer(1024); -// // capacity() == 1024. size() == BytesRemaining == BytesConsumed() == 0. +// // capacity() == 1024. size() == BytesRemaining() == BytesConsumed() == 0. // // data() points to the beginning of the buffer. // // // Read() takes an IOBuffer. @@ -94,7 +94,7 @@ namespace net { // buf->DidConsume(bytes_written); // } // // BytesRemaining() == 0. BytesConsumed() == size(). -// // data() points to the end of the comsumed bytes (exclusive). +// // data() points to the end of the consumed bytes (exclusive). // // // If you want to reuse the buffer, be sure to clear the buffer. // buf->Clear(); @@ -161,7 +161,7 @@ class HttpStreamParser::SeekableIOBuffer : public net::IOBuffer { } char* real_data_; - int capacity_; + const int capacity_; int size_; int used_; }; @@ -293,6 +293,7 @@ int HttpStreamParser::ReadResponseHeaders(const CompletionCallback& callback) { DCHECK(io_state_ == STATE_REQUEST_SENT || io_state_ == STATE_DONE); DCHECK(callback_.is_null()); DCHECK(!callback.is_null()); + DCHECK_EQ(0, read_buf_unused_offset_); // This function can be called with io_state_ == STATE_DONE if the // connection is closed after seeing just a 1xx response code. @@ -304,8 +305,8 @@ int HttpStreamParser::ReadResponseHeaders(const CompletionCallback& callback) { if (read_buf_->offset() > 0) { // Simulate the state where the data was just read from the socket. - result = read_buf_->offset() - read_buf_unused_offset_; - read_buf_->set_offset(read_buf_unused_offset_); + result = read_buf_->offset(); + read_buf_->set_offset(0); } if (result > 0) io_state_ = STATE_READ_HEADERS_COMPLETE; @@ -517,6 +518,8 @@ int HttpStreamParser::DoReadHeaders() { } int HttpStreamParser::DoReadHeadersComplete(int result) { + DCHECK_EQ(0, read_buf_unused_offset_); + if (result == 0) result = ERR_CONNECTION_CLOSED; @@ -580,41 +583,43 @@ int HttpStreamParser::DoReadHeadersComplete(int result) { if (end_of_header_offset == -1) { io_state_ = STATE_READ_HEADERS; // Prevent growing the headers buffer indefinitely. - if (read_buf_->offset() - read_buf_unused_offset_ >= kMaxHeaderBufSize) { + if (read_buf_->offset() >= kMaxHeaderBufSize) { io_state_ = STATE_DONE; return ERR_RESPONSE_HEADERS_TOO_BIG; } } else { - // Note where the headers stop. - read_buf_unused_offset_ = end_of_header_offset; - - if (response_->headers->response_code() / 100 == 1) { - // After processing a 1xx response, the caller will ask for the next - // header, so reset state to support that. We don't just skip these - // completely because 1xx codes aren't acceptable when establishing a - // tunnel. - io_state_ = STATE_REQUEST_SENT; - response_header_start_offset_ = -1; - } else { - io_state_ = STATE_BODY_PENDING; - CalculateResponseBodySize(); - // If the body is 0, the caller may not call ReadResponseBody, which - // is where any extra data is copied to read_buf_, so we move the - // data here and transition to DONE. - if (response_body_length_ == 0) { + CalculateResponseBodySize(); + // If the body is zero length, the caller may not call ReadResponseBody, + // which is where any extra data is copied to read_buf_, so we move the + // data here. + if (response_body_length_ == 0) { + int extra_bytes = read_buf_->offset() - end_of_header_offset; + if (extra_bytes) { + CHECK_GT(extra_bytes, 0); + memmove(read_buf_->StartOfBuffer(), + read_buf_->StartOfBuffer() + end_of_header_offset, + extra_bytes); + } + read_buf_->SetCapacity(extra_bytes); + if (response_->headers->response_code() / 100 == 1) { + // After processing a 1xx response, the caller will ask for the next + // header, so reset state to support that. We don't completely ignore a + // 1xx response because it cannot be returned in reply to a CONNECT + // request so we return OK here, which lets the caller inspect the + // response and reject it in the event that we're setting up a CONNECT + // tunnel. + response_header_start_offset_ = -1; + response_body_length_ = -1; + io_state_ = STATE_REQUEST_SENT; + } else { io_state_ = STATE_DONE; - int extra_bytes = read_buf_->offset() - read_buf_unused_offset_; - if (extra_bytes) { - CHECK_GT(extra_bytes, 0); - memmove(read_buf_->StartOfBuffer(), - read_buf_->StartOfBuffer() + read_buf_unused_offset_, - extra_bytes); - } - read_buf_->SetCapacity(extra_bytes); - read_buf_unused_offset_ = 0; - return OK; } + return OK; } + + // Note where the headers stop. + read_buf_unused_offset_ = end_of_header_offset; + io_state_ = STATE_BODY_PENDING; } return result; } @@ -751,20 +756,19 @@ int HttpStreamParser::DoReadBodyComplete(int result) { int HttpStreamParser::ParseResponseHeaders() { int end_offset = -1; + DCHECK_EQ(0, read_buf_unused_offset_); // Look for the start of the status line, if it hasn't been found yet. if (response_header_start_offset_ < 0) { response_header_start_offset_ = HttpUtil::LocateStartOfStatusLine( - read_buf_->StartOfBuffer() + read_buf_unused_offset_, - read_buf_->offset() - read_buf_unused_offset_); + read_buf_->StartOfBuffer(), read_buf_->offset()); } if (response_header_start_offset_ >= 0) { - end_offset = HttpUtil::LocateEndOfHeaders( - read_buf_->StartOfBuffer() + read_buf_unused_offset_, - read_buf_->offset() - read_buf_unused_offset_, - response_header_start_offset_); - } else if (read_buf_->offset() - read_buf_unused_offset_ >= 8) { + end_offset = HttpUtil::LocateEndOfHeaders(read_buf_->StartOfBuffer(), + read_buf_->offset(), + response_header_start_offset_); + } else if (read_buf_->offset() >= 8) { // Enough data to decide that this is an HTTP/0.9 response. // 8 bytes = (4 bytes of junk) + "http".length() end_offset = 0; @@ -776,14 +780,16 @@ int HttpStreamParser::ParseResponseHeaders() { int rv = DoParseResponseHeaders(end_offset); if (rv < 0) return rv; - return end_offset + read_buf_unused_offset_; + return end_offset; } int HttpStreamParser::DoParseResponseHeaders(int end_offset) { scoped_refptr<HttpResponseHeaders> headers; + DCHECK_EQ(0, read_buf_unused_offset_); + if (response_header_start_offset_ >= 0) { headers = new HttpResponseHeaders(HttpUtil::AssembleRawHeaders( - read_buf_->StartOfBuffer() + read_buf_unused_offset_, end_offset)); + read_buf_->StartOfBuffer(), end_offset)); } else { // Enough data was read -- there is no status line. headers = new HttpResponseHeaders(std::string("HTTP/0.9 200 OK")); @@ -830,13 +836,16 @@ void HttpStreamParser::CalculateResponseBodySize() { // (informational), 204 (no content), and 304 (not modified) responses // MUST NOT include a message-body. All other responses do include a // message-body, although it MAY be of zero length. - switch (response_->headers->response_code()) { - // Note that 1xx was already handled earlier. - case 204: // No Content - case 205: // Reset Content - case 304: // Not Modified - response_body_length_ = 0; - break; + if (response_->headers->response_code() / 100 == 1) { + response_body_length_ = 0; + } else { + switch (response_->headers->response_code()) { + case 204: // No Content + case 205: // Reset Content + case 304: // Not Modified + response_body_length_ = 0; + break; + } } if (request_->method == "HEAD") response_body_length_ = 0; diff --git a/chromium/net/http/http_stream_parser.h b/chromium/net/http/http_stream_parser.h index a41e393b485..43e1514c2fe 100644 --- a/chromium/net/http/http_stream_parser.h +++ b/chromium/net/http/http_stream_parser.h @@ -178,7 +178,8 @@ class NET_EXPORT_PRIVATE HttpStreamParser { scoped_refptr<GrowableIOBuffer> read_buf_; // Offset of the first unused byte in |read_buf_|. May be nonzero due to - // a 1xx header, or body data in the same packet as header data. + // body data in the same packet as header data but is zero when reading + // headers. int read_buf_unused_offset_; // The amount beyond |read_buf_unused_offset_| where the status line starts; diff --git a/chromium/net/http/http_stream_parser_unittest.cc b/chromium/net/http/http_stream_parser_unittest.cc index d530c2deab7..6e0053e6255 100644 --- a/chromium/net/http/http_stream_parser_unittest.cc +++ b/chromium/net/http/http_stream_parser_unittest.cc @@ -8,6 +8,7 @@ #include "base/files/file_path.h" #include "base/files/scoped_temp_dir.h" #include "base/memory/ref_counted.h" +#include "base/run_loop.h" #include "base/strings/string_piece.h" #include "base/strings/stringprintf.h" #include "net/base/io_buffer.h" @@ -116,29 +117,34 @@ TEST(HttpStreamParser, ShouldMergeRequestHeadersAndBody_ChunkedBody) { } TEST(HttpStreamParser, ShouldMergeRequestHeadersAndBody_FileBody) { - ScopedVector<UploadElementReader> element_readers; - - // Create an empty temporary file. - base::ScopedTempDir temp_dir; - ASSERT_TRUE(temp_dir.CreateUniqueTempDir()); - base::FilePath temp_file_path; - ASSERT_TRUE(file_util::CreateTemporaryFileInDir(temp_dir.path(), - &temp_file_path)); - - element_readers.push_back( - new UploadFileElementReader(base::MessageLoopProxy::current().get(), - temp_file_path, - 0, - 0, - base::Time())); - - scoped_ptr<UploadDataStream> body(new UploadDataStream(&element_readers, 0)); - TestCompletionCallback callback; - ASSERT_EQ(ERR_IO_PENDING, body->Init(callback.callback())); - ASSERT_EQ(OK, callback.WaitForResult()); - // Shouldn't be merged if upload data carries a file, as it's not in-memory. - ASSERT_FALSE(HttpStreamParser::ShouldMergeRequestHeadersAndBody( - "some header", body.get())); + { + ScopedVector<UploadElementReader> element_readers; + + // Create an empty temporary file. + base::ScopedTempDir temp_dir; + ASSERT_TRUE(temp_dir.CreateUniqueTempDir()); + base::FilePath temp_file_path; + ASSERT_TRUE(file_util::CreateTemporaryFileInDir(temp_dir.path(), + &temp_file_path)); + + element_readers.push_back( + new UploadFileElementReader(base::MessageLoopProxy::current().get(), + temp_file_path, + 0, + 0, + base::Time())); + + scoped_ptr<UploadDataStream> body( + new UploadDataStream(&element_readers, 0)); + TestCompletionCallback callback; + ASSERT_EQ(ERR_IO_PENDING, body->Init(callback.callback())); + ASSERT_EQ(OK, callback.WaitForResult()); + // Shouldn't be merged if upload data carries a file, as it's not in-memory. + ASSERT_FALSE(HttpStreamParser::ShouldMergeRequestHeadersAndBody( + "some header", body.get())); + } + // UploadFileElementReaders may post clean-up tasks on destruction. + base::RunLoop().RunUntilIdle(); } TEST(HttpStreamParser, ShouldMergeRequestHeadersAndBody_SmallBodyInMemory) { @@ -220,7 +226,7 @@ TEST(HttpStreamParser, AsyncChunkAndAsyncSocket) { ASSERT_EQ(OK, rv); scoped_ptr<ClientSocketHandle> socket_handle(new ClientSocketHandle); - socket_handle->set_socket(transport.release()); + socket_handle->SetSocket(transport.PassAs<StreamSocket>()); HttpRequestInfo request_info; request_info.method = "GET"; @@ -375,7 +381,7 @@ TEST(HttpStreamParser, TruncatedHeaders) { ASSERT_EQ(OK, rv); scoped_ptr<ClientSocketHandle> socket_handle(new ClientSocketHandle); - socket_handle->set_socket(transport.release()); + socket_handle->SetSocket(transport.PassAs<StreamSocket>()); HttpRequestInfo request_info; request_info.method = "GET"; diff --git a/chromium/net/http/http_transaction.h b/chromium/net/http/http_transaction.h index ec3fc088a03..d44050080ba 100644 --- a/chromium/net/http/http_transaction.h +++ b/chromium/net/http/http_transaction.h @@ -109,6 +109,9 @@ class NET_EXPORT_PRIVATE HttpTransaction { // of the stream. This is equivalent to performing an extra Read() at the end // that should return 0 bytes. This method should not be called if the // transaction is busy processing a previous operation (like a pending Read). + // + // DoneReading may also be called before the first Read() to notify that the + // entire response body is to be ignored (e.g., in a redirect). virtual void DoneReading() = 0; // Returns the response info for this transaction or NULL if the response diff --git a/chromium/net/http/http_util_icu.cc b/chromium/net/http/http_util_icu.cc index 64e7424df14..4f38f84d75b 100644 --- a/chromium/net/http/http_util_icu.cc +++ b/chromium/net/http/http_util_icu.cc @@ -14,7 +14,7 @@ namespace net { // static std::string HttpUtil::PathForRequest(const GURL& url) { - DCHECK(url.is_valid() && (url.SchemeIs("http") || url.SchemeIs("https"))); + DCHECK(url.is_valid() && url.SchemeIsHTTPOrHTTPS()); if (url.has_query()) return url.path() + "?" + url.query(); return url.path(); @@ -23,8 +23,7 @@ std::string HttpUtil::PathForRequest(const GURL& url) { // static std::string HttpUtil::SpecForRequest(const GURL& url) { // We may get ftp scheme when fetching ftp resources through proxy. - DCHECK(url.is_valid() && (url.SchemeIs("http") || - url.SchemeIs("https") || + DCHECK(url.is_valid() && (url.SchemeIsHTTPOrHTTPS() || url.SchemeIs("ftp"))); return SimplifyUrlForRequest(url).spec(); } diff --git a/chromium/net/http/proxy_connect_redirect_http_stream.cc b/chromium/net/http/proxy_connect_redirect_http_stream.cc index f30f33c002c..59bb0146953 100644 --- a/chromium/net/http/proxy_connect_redirect_http_stream.cc +++ b/chromium/net/http/proxy_connect_redirect_http_stream.cc @@ -109,6 +109,10 @@ void ProxyConnectRedirectHttpStream::Drain(HttpNetworkSession* session) { NOTREACHED(); } +void ProxyConnectRedirectHttpStream::SetPriority(RequestPriority priority) { + // Nothing to do. +} + UploadProgress ProxyConnectRedirectHttpStream::GetUploadProgress() const { NOTREACHED(); return UploadProgress(); diff --git a/chromium/net/http/proxy_connect_redirect_http_stream.h b/chromium/net/http/proxy_connect_redirect_http_stream.h index f39ec76b9b9..c335c218c71 100644 --- a/chromium/net/http/proxy_connect_redirect_http_stream.h +++ b/chromium/net/http/proxy_connect_redirect_http_stream.h @@ -59,6 +59,10 @@ class ProxyConnectRedirectHttpStream : public HttpStream { SSLCertRequestInfo* cert_request_info) OVERRIDE; virtual bool IsSpdyHttpStream() const OVERRIDE; virtual void Drain(HttpNetworkSession* session) OVERRIDE; + + // This function may be called. + virtual void SetPriority(RequestPriority priority) OVERRIDE; + virtual UploadProgress GetUploadProgress() const OVERRIDE; virtual HttpStream* RenewStreamForAuth() OVERRIDE; diff --git a/chromium/net/http/transport_security_state_static.certs b/chromium/net/http/transport_security_state_static.certs index 610e56a103e..ca6ef535bc8 100644 --- a/chromium/net/http/transport_security_state_static.certs +++ b/chromium/net/http/transport_security_state_static.certs @@ -85,6 +85,65 @@ sha1/vq7OyjSnqOco9nyMCDGdy77eijM= GoogleG2 sha1/Q9rWMO5T+KmAym79hfRqo3mQ4Oo= +ThawteSGCCA +-----BEGIN CERTIFICATE----- +MIIDIzCCAoygAwIBAgIEMAAAAjANBgkqhkiG9w0BAQUFADBfMQswCQYDVQQGEwJV +UzEXMBUGA1UEChMOVmVyaVNpZ24sIEluYy4xNzA1BgNVBAsTLkNsYXNzIDMgUHVi +bGljIFByaW1hcnkgQ2VydGlmaWNhdGlvbiBBdXRob3JpdHkwHhcNMDQwNTEzMDAw +MDAwWhcNMTQwNTEyMjM1OTU5WjBMMQswCQYDVQQGEwJaQTElMCMGA1UEChMcVGhh +d3RlIENvbnN1bHRpbmcgKFB0eSkgTHRkLjEWMBQGA1UEAxMNVGhhd3RlIFNHQyBD +QTCBnzANBgkqhkiG9w0BAQEFAAOBjQAwgYkCgYEA1NNn0I0Vf67NMf59HZGhPwtx +PKzMyGT7Y/wySweUvW+Aui/hBJPAM/wJMyPpC3QrccQDxtLN4i/1CWPN/0ilAL/g +5/OIty0y3pg25gqtAHvEZEo7hHUD8nCSfQ5i9SGraTaEMXWQ+L/HbIgbBpV8yeWo +3nWhLHpo39XKHIdYYBkCAwEAAaOB/jCB+zASBgNVHRMBAf8ECDAGAQH/AgEAMAsG +A1UdDwQEAwIBBjARBglghkgBhvhCAQEEBAMCAQYwKAYDVR0RBCEwH6QdMBsxGTAX +BgNVBAMTEFByaXZhdGVMYWJlbDMtMTUwMQYDVR0fBCowKDAmoCSgIoYgaHR0cDov +L2NybC52ZXJpc2lnbi5jb20vcGNhMy5jcmwwMgYIKwYBBQUHAQEEJjAkMCIGCCsG +AQUFBzABhhZodHRwOi8vb2NzcC50aGF3dGUuY29tMDQGA1UdJQQtMCsGCCsGAQUF +BwMBBggrBgEFBQcDAgYJYIZIAYb4QgQBBgpghkgBhvhFAQgBMA0GCSqGSIb3DQEB +BQUAA4GBAFWsY+reod3SkF+fC852vhNRj5PZBSvIG3dLrWlQoe7e3P3bB+noOZTc +q3J5Lwa/q4FwxKjt6lM07e8eU9kGx1Yr0Vz00YqOtCuxN5BICEIlxT6Ky3/rbwTR +bcV0oveifHtgPHfNDs5IAn8BL7abN+AqKjbc1YXWrOU/VG+WHgWv +-----END CERTIFICATE----- + +VeriSignClass3SSPIntermediateCA +-----BEGIN CERTIFICATE----- +MIIGVDCCBTygAwIBAgIQGYH0QFTS4OtUK7v7RciQfjANBgkqhkiG9w0BAQUFADCB +yjELMAkGA1UEBhMCVVMxFzAVBgNVBAoTDlZlcmlTaWduLCBJbmMuMR8wHQYDVQQL +ExZWZXJpU2lnbiBUcnVzdCBOZXR3b3JrMTowOAYDVQQLEzEoYykgMTk5OSBWZXJp +U2lnbiwgSW5jLiAtIEZvciBhdXRob3JpemVkIHVzZSBvbmx5MUUwQwYDVQQDEzxW +ZXJpU2lnbiBDbGFzcyAzIFB1YmxpYyBQcmltYXJ5IENlcnRpZmljYXRpb24gQXV0 +aG9yaXR5IC0gRzMwHhcNMTEwMTA3MDAwMDAwWhcNMTMxMjMxMjM1OTU5WjB2MQsw +CQYDVQQGEwJVUzEXMBUGA1UEChMOVmVyaVNpZ24sIEluYy4xHzAdBgNVBAsTFlZl +cmlTaWduIFRydXN0IE5ldHdvcmsxLTArBgNVBAMTJFZlcmlTaWduIENsYXNzIDMg +U1NQIEludGVybWVkaWF0ZSBDQTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoC +ggEBANfMaBonchSI7reVYNNe3hhSwUY/fbEmnDwCoonR2MFXsQkP9n8yNaU1nhRT +Eovg4zAetI+e0bDAt9/0Lw/n1x/FdiTTPdMN6SxKLqc8z7xql0MZ+MBzyhsstmIB +RmJWkGisFFAZ51BYB/k9AfLtHjQnvc1yHYBgo0ySG2a6ejkJd2r6U/dvjgbu2dSj +Eo5XJGl//xSSLKs4HPhkuAsdZr2HqPiBwjlFpCd//Fs8he43JBI60+bRSBiUKpQC +ssu6oAj2rvKcy2AMTvjIAlz9Iy3B92fB1Q1JxpbWcLochUca7/NFQTkKMaVeBXxy +i2D+SFWfuBLtcl7p/kbtwqfiDbMCAwEAAaOCAocwggKDMA8GA1UdEwEB/wQFMAMB +Af8wDgYDVR0PAQH/BAQDAgEGMIHoBgNVHSAEgeAwgd0wDwYNYIZIAYb4RQEHFwMB +BjAPBg1ghkgBhvhFAQcXAwEHMA8GDWCGSAGG+EUBBxcDAQgwDwYNYIZIAYb4RQEH +FwMBDTAPBg1ghkgBhvhFAQcXAwEOMA8GDWCGSAGG+EUBBxcDAQ8wDwYNYIZIAYb4 +RQEHFwMBETAPBg1ghkgBhvhFAQcXAwEUMA8GDWCGSAGG+EUBBxcDARcwDwYNYIZI +AYb4RQEHFwMBGDAPBg1ghkgBhvhFAQcXAwEZMA8GDWCGSAGG+EUBBxcDARowDwYN +YIZIAYb4RQEHFwMBGzA4BgNVHR8EMTAvMC2gK6AphidodHRwOi8vc3NwLWNybC52 +ZXJpc2lnbi5jb20vcGNhMy1nMy5jcmwwKAYDVR0RBCEwH6QdMBsxGTAXBgNVBAMT +EFZlcmlTaWduTVBLSS0xLTgwHQYDVR0OBBYEFCwx/8HOq/lN6IkVwGry5atCfUL6 +MIHxBgNVHSMEgekwgeahgdCkgc0wgcoxCzAJBgNVBAYTAlVTMRcwFQYDVQQKEw5W +ZXJpU2lnbiwgSW5jLjEfMB0GA1UECxMWVmVyaVNpZ24gVHJ1c3QgTmV0d29yazE6 +MDgGA1UECxMxKGMpIDE5OTkgVmVyaVNpZ24sIEluYy4gLSBGb3IgYXV0aG9yaXpl +ZCB1c2Ugb25seTFFMEMGA1UEAxM8VmVyaVNpZ24gQ2xhc3MgMyBQdWJsaWMgUHJp +bWFyeSBDZXJ0aWZpY2F0aW9uIEF1dGhvcml0eSAtIEczghEAm34GSaM+YrnV7pBI +cSnvVzANBgkqhkiG9w0BAQUFAAOCAQEAIS19vzG9j+KXiQ0G1bOuJCeiD9KKW1+8 +69cutvgDf3hEvrw39Gr2ek3cAdso7dvwW0Z17muzpHV08gWTjjKba8mBzjijmgr9 +I2vE2K/Ls72WJvTDUjCAHfBJKeK1q8v7xv1xtf2Jz7BV8sNH3kDB7jhhE++8zLVC +gyFilU0KZfhBpLPVlVYnLozRdvsHfNnO/JskJvRqhDYbeC5ginQT0m5sTQiyTYqL +/IU+i82TxANXjC7syl0dfcGr8pJ85T9bF1EZLxdgikAYLKPGTuXMwOGqT5bR0dKD +lWShiGTRl7HW0KJMg05F0HjOnYpdOYGaFrQghecrkcrRPRevSdFVHQ== +-----END CERTIFICATE----- + EquifaxSecureCA -----BEGIN CERTIFICATE----- MIIDIDCCAomgAwIBAgIENd70zzANBgkqhkiG9w0BAQUFADBOMQswCQYDVQQGEwJV @@ -298,6 +357,54 @@ vEsXCS+0yx5DaMkHJ8HSXPfqIbloEpw8nL+e/IBcm2PN7EeqJSdnoDfzAIJ9VNep +OkuE6N36B9K -----END CERTIFICATE----- +DigiCertAssuredIDRoot +-----BEGIN CERTIFICATE----- +MIIDtzCCAp+gAwIBAgIQDOfg5RfYRv6P5WD8G/AwOTANBgkqhkiG9w0BAQUFADBl +MQswCQYDVQQGEwJVUzEVMBMGA1UEChMMRGlnaUNlcnQgSW5jMRkwFwYDVQQLExB3 +d3cuZGlnaWNlcnQuY29tMSQwIgYDVQQDExtEaWdpQ2VydCBBc3N1cmVkIElEIFJv +b3QgQ0EwHhcNMDYxMTEwMDAwMDAwWhcNMzExMTEwMDAwMDAwWjBlMQswCQYDVQQG +EwJVUzEVMBMGA1UEChMMRGlnaUNlcnQgSW5jMRkwFwYDVQQLExB3d3cuZGlnaWNl +cnQuY29tMSQwIgYDVQQDExtEaWdpQ2VydCBBc3N1cmVkIElEIFJvb3QgQ0EwggEi +MA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQCtDhXO5EOAXLGH87dg+XESpa7c +JpSIqvTO9SA5KFhgDPiA2qkVlTJhPLWxKISKityfCgyDF3qPkKyK53lTXDGEKvYP +mDI2dsze3Tyoou9q+yHyUmHfnyDXH+Kx2f4YZNISW1/5WBg1vEfNoTb5a3/UsDg+ +wRvDjDPZ2C8Y/igPs6eD1sNuRMBhNZYW/lmci3Zt1/GiSw0r/wty2p5g0I6QNcZ4 +VYcgoc/lbQrISXwxmDNsIumH0DJaoroTghHtORedmTpyoeb6pNnVFzF1roV9Iq4/ +AUaG9ih5yLHa5FcXxH4cDrC0kqZWs72yl+2qp/C3xag/lRbQ/6GW6whfGHdPAgMB +AAGjYzBhMA4GA1UdDwEB/wQEAwIBhjAPBgNVHRMBAf8EBTADAQH/MB0GA1UdDgQW +BBRF66Kv9JLLgjEtUYunpyGd823IDzAfBgNVHSMEGDAWgBRF66Kv9JLLgjEtUYun +pyGd823IDzANBgkqhkiG9w0BAQUFAAOCAQEAog683+Lt8ONyc3pklL/3cmbYMuRC +dWKuh+vy1dneVrOfzM4UKLkNl2BcEkxY5NM9g0lFWJc1aRqoR+pWxnmrEthngYTf +fwk8lOa4JiwgvT2zKIn3X/8i4peEH+ll74fg38FnSbNd67IJKusm7Xi+fT8r87cm +NW1fiQG2SVufAQWbqz0lwcy2f8Lxb4bG+mRo64EtlOtCt/qMHt1i8b5QZ7dsvfPx +H2sMNgcWfzd8qVttevESRmCD1ycEvkvOl77DZypoEd+A5wwzZr8TDRRu838fYxAe ++o0bJW1sj6W3YQGx0qMmoRBxna3iw/nDmVG3KwcIzi7mULKn+gpFL6Lw8g== +-----END CERTIFICATE----- + +DigiCertGlobalRoot +-----BEGIN CERTIFICATE----- +MIIDrzCCApegAwIBAgIQCDvgVpBCRrGhdWrJWZHHSjANBgkqhkiG9w0BAQUFADBh +MQswCQYDVQQGEwJVUzEVMBMGA1UEChMMRGlnaUNlcnQgSW5jMRkwFwYDVQQLExB3 +d3cuZGlnaWNlcnQuY29tMSAwHgYDVQQDExdEaWdpQ2VydCBHbG9iYWwgUm9vdCBD +QTAeFw0wNjExMTAwMDAwMDBaFw0zMTExMTAwMDAwMDBaMGExCzAJBgNVBAYTAlVT +MRUwEwYDVQQKEwxEaWdpQ2VydCBJbmMxGTAXBgNVBAsTEHd3dy5kaWdpY2VydC5j +b20xIDAeBgNVBAMTF0RpZ2lDZXJ0IEdsb2JhbCBSb290IENBMIIBIjANBgkqhkiG +9w0BAQEFAAOCAQ8AMIIBCgKCAQEA4jvhEXLeqKTTo1eqUKKPC3eQyaKl7hLOllsB +CSDMAZOnTjC3U/dDxGkAV53ijSLdhwZAAIEJzs4bg7/fzTtxRuLWZscFs3YnFo97 +nh6Vfe63SKMI2tavegw5BmV/Sl0fvBf4q77uKNd0f3p4mVmFaG5cIzJLv07A6Fpt +43C/dxC//AH2hdmoRBBYMql1GNXRor5H4idq9Joz+EkIYIvUX7Q6hL+hqkpMfT7P +T19sdl6gSzeRntwi5m3OFBqOasv+zbMUZBfHWymeMr/y7vrTC0LUq7dBMtoM1O/4 +gdW7jVg/tRvoSSiicNoxBN33shbyTApOB6jtSj1etX+jkMOvJwIDAQABo2MwYTAO +BgNVHQ8BAf8EBAMCAYYwDwYDVR0TAQH/BAUwAwEB/zAdBgNVHQ4EFgQUA95QNVbR +TLtm8KPiGxvDl7I90VUwHwYDVR0jBBgwFoAUA95QNVbRTLtm8KPiGxvDl7I90VUw +DQYJKoZIhvcNAQEFBQADggEBAMucN6pIExIK+t1EnE9SsPTfrgT1eXkIoyQY/Esr +hMAtudXH/vTBH1jLuG2cenTnmCmrEbXjcKChzUyImZOMkXDiqw8cvpOp/2PV5Adg +06O/nVsJ8dWO41P0jmP6P6fbtGbfYmbW0W5BjfIttep3Sp+dWOIrWcBAI+0tKIJF +PnlUkiaY4IBIqDfv8NZ5YBberOgOzW6sRBc4L0na4UU+Krk2U886UAb3LujEV0ls +YSEY1QSteDwsOoBrp+uvFRTp2InBuThs4pFsiv9kuXclVzDAGySj4dzp30d8tbQk +CAUw7C29C79Fv1C5qfPrmAESrciIxpg0X40KPMbp1ZWVbd4= +-----END CERTIFICATE----- + Tor1 sha1/juNxSTv9UANmpC9kF5GKpmWNx3Y= Tor2 @@ -1115,6 +1222,29 @@ GwnpXtlR22ciYaQqPEh346B8pt5zohQDhT37qw4wxYMWM4ETCJ57NE7fQMh017l9 lZPvy5TYnh+dXIVtx6quTx8itc2VrbqnzPmrC3p/ -----END CERTIFICATE----- +BaltimoreCyberTrustRoot +-----BEGIN CERTIFICATE----- +MIIDdzCCAl+gAwIBAgIEAgAAuTANBgkqhkiG9w0BAQUFADBaMQswCQYDVQQGEwJJ +RTESMBAGA1UEChMJQmFsdGltb3JlMRMwEQYDVQQLEwpDeWJlclRydXN0MSIwIAYD +VQQDExlCYWx0aW1vcmUgQ3liZXJUcnVzdCBSb290MB4XDTAwMDUxMjE4NDYwMFoX +DTI1MDUxMjIzNTkwMFowWjELMAkGA1UEBhMCSUUxEjAQBgNVBAoTCUJhbHRpbW9y +ZTETMBEGA1UECxMKQ3liZXJUcnVzdDEiMCAGA1UEAxMZQmFsdGltb3JlIEN5YmVy +VHJ1c3QgUm9vdDCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAKMEuyKr +mD1X6CZymrV51Cni4eiVgLGw41uOKymaZN+hXe2wCQVt2yguzmKiYv60iNoS6zjr +IZ3AQSsBUnuId9Mcj8e6uYi1agnnc+gRQKfRzMpijS3ljwumUNKoUMMo6vWrJYeK +mpYcqWe4PwzV9/lSEy/CG9VwcPCPwBLKBsua4dnKM3p31vjsufFoREJIE9LAwqSu +XmD+tqYF/LTdB1kC1FkYmGP1pWPgkAx9XbIGevOF6uvUA65ehD5f/xXtabz5OTZy +dc93Uk3zyZAsuT3lySNTPx8kmCFcB5kpvcY67Oduhjprl3RjM71oGDHweI12v/ye +jl0qhqdNkNwnGjkCAwEAAaNFMEMwHQYDVR0OBBYEFOWdWTCCR1jMrPoIVDaGezq1 +BE3wMBIGA1UdEwEB/wQIMAYBAf8CAQMwDgYDVR0PAQH/BAQDAgEGMA0GCSqGSIb3 +DQEBBQUAA4IBAQCFDF2O5G9RaEIFoN27TyclhAO992T9Ldcw46QQF+vaKSm2eT92 +9hkTI7gQCvlYpNRhcL0EYWoSihfVCr3FvDB81ukMJY2GQE/szKN+OMY3EU/t3Wgx +jkzSswF07r51XgdIGn9w/xZchMB5hbgF/X++ZRGjD8ACtPhSNzkE1akxehi/oCr0 +Epn3o0WC4zxe9Z2etciefC7IpJ5OCBRLbf1wbWsaY71k5h+3zvDyny67G7fyUIhz +ksLi4xaNmjICq44Y3ekQEe5+NauQrz4wlHrQMz2nZQ/1/I6eYs9HRCwBXbsdtTLS +R9I4LtD+gdwyah617jzV/OeBHRnDJELqYzmp +-----END CERTIFICATE----- + Tor2web -----BEGIN CERTIFICATE----- MIIEgjCCA2qgAwIBAgISESHiIwbyj8tbXjvCF3lADzOxMA0GCSqGSIb3DQEBBQUA diff --git a/chromium/net/http/transport_security_state_static.h b/chromium/net/http/transport_security_state_static.h index 22b5cf957d0..a924550cbec 100644 --- a/chromium/net/http/transport_security_state_static.h +++ b/chromium/net/http/transport_security_state_static.h @@ -42,6 +42,14 @@ static const char kSPKIHash_GoogleG2[] = "\x43\xda\xd6\x30\xee\x53\xf8\xa9\x80\xca" "\x6e\xfd\x85\xf4\x6a\xa3\x79\x90\xe0\xea"; +static const char kSPKIHash_ThawteSGCCA[] = + "\x87\x31\xea\x0e\x3d\xf5\xe8\x70\x3e\x83" + "\x72\x57\x77\xa9\x65\x3b\x3b\xfa\x5e\x14"; + +static const char kSPKIHash_VeriSignClass3SSPIntermediateCA[] = + "\x99\x6a\x20\x6a\x85\x57\x62\xcb\x9a\xf2" + "\x02\x37\xb3\xc0\x69\x5d\xa9\x1e\xc2\x22"; + static const char kSPKIHash_EquifaxSecureCA[] = "\x48\xe6\x68\xf9\x2b\xd2\xb2\x95\xd7\x47" "\xd8\x23\x20\x10\x4f\x33\x98\x90\x9f\xd4"; @@ -78,6 +86,14 @@ static const char kSPKIHash_DigiCertEVRoot[] = "\x83\x31\x7e\x62\x85\x42\x53\xd6\xd7\x78" "\x31\x90\xec\x91\x90\x56\xe9\x91\xb9\xe3"; +static const char kSPKIHash_DigiCertAssuredIDRoot[] = + "\x68\x33\x0e\x61\x35\x85\x21\x59\x29\x83" + "\xa3\xc8\xd2\xd2\xe1\x40\x6e\x7a\xb3\xc1"; + +static const char kSPKIHash_DigiCertGlobalRoot[] = + "\xd5\x2e\x13\xc1\xab\xe3\x49\xda\xe8\xb4" + "\x95\x94\xef\x7c\x38\x43\x60\x64\x66\xbd"; + static const char kSPKIHash_Tor1[] = "\x8e\xe3\x71\x49\x3b\xfd\x50\x03\x66\xa4" "\x2f\x64\x17\x91\x8a\xa6\x65\x8d\xc7\x76"; @@ -218,6 +234,10 @@ static const char kSPKIHash_GTECyberTrustGlobalRoot[] = "\x59\x79\x12\xde\x61\x75\xd6\x6f\xc4\x23" "\xb7\x77\x13\x74\xc7\x96\xde\x6f\x88\x72"; +static const char kSPKIHash_BaltimoreCyberTrustRoot[] = + "\x30\xa4\xe6\x4f\xde\x76\x8a\xfc\xed\x5a" + "\x90\x84\x28\x30\x46\x79\x2c\x29\x15\x70"; + static const char kSPKIHash_Tor2web[] = "\x19\xe5\xb5\x87\x1b\xd4\x83\x2e\xc8\xf5" "\x94\x97\xfe\xc6\x5e\xfb\x48\xe3\x33\xb1"; @@ -264,6 +284,8 @@ static const char* const kGoogleRejectedCerts[] = { kSPKIHash_Intel, kSPKIHash_TCTrustCenter, kSPKIHash_Vodafone, + kSPKIHash_ThawteSGCCA, + kSPKIHash_VeriSignClass3SSPIntermediateCA, NULL, }; #define kGooglePins { \ @@ -303,6 +325,9 @@ static const char* const kTwitterComAcceptableCerts[] = { kSPKIHash_GeoTrustPrimary, kSPKIHash_GeoTrustPrimary_G2, kSPKIHash_GeoTrustPrimary_G3, + kSPKIHash_DigiCertGlobalRoot, + kSPKIHash_DigiCertEVRoot, + kSPKIHash_DigiCertAssuredIDRoot, kSPKIHash_Twitter1, NULL, }; @@ -330,6 +355,9 @@ static const char* const kTwitterCDNAcceptableCerts[] = { kSPKIHash_GeoTrustPrimary, kSPKIHash_GeoTrustPrimary_G2, kSPKIHash_GeoTrustPrimary_G3, + kSPKIHash_DigiCertGlobalRoot, + kSPKIHash_DigiCertEVRoot, + kSPKIHash_DigiCertAssuredIDRoot, kSPKIHash_Twitter1, kSPKIHash_Entrust_2048, kSPKIHash_Entrust_EV, @@ -348,6 +376,7 @@ static const char* const kTwitterCDNAcceptableCerts[] = { kSPKIHash_UTNUSERFirstHardware, kSPKIHash_UTNUSERFirstObject, kSPKIHash_GTECyberTrustGlobalRoot, + kSPKIHash_BaltimoreCyberTrustRoot, NULL, }; #define kTwitterCDNPins { \ @@ -382,7 +411,7 @@ static const char* const kCryptoCatAcceptableCerts[] = { static const struct HSTSPreload kPreloadedSTS[] = { {25, true, "\013pinningtest\007appspot\003com", false, kTestPins, DOMAIN_APPSPOT_COM }, {12, true, "\006google\003com", false, kGooglePins, DOMAIN_GOOGLE_COM }, - {19, true, "\006health\006google\003com", true, kGooglePins, DOMAIN_GOOGLE_COM }, + {19, true, "\006wallet\006google\003com", true, kGooglePins, DOMAIN_GOOGLE_COM }, {21, true, "\010checkout\006google\003com", true, kGooglePins, DOMAIN_GOOGLE_COM }, {19, true, "\006chrome\006google\003com", true, kGooglePins, DOMAIN_GOOGLE_COM }, {17, true, "\004docs\006google\003com", true, kGooglePins, DOMAIN_GOOGLE_COM }, @@ -414,6 +443,7 @@ static const struct HSTSPreload kPreloadedSTS[] = { {17, true, "\004code\006google\003com", true, kGooglePins, DOMAIN_GOOGLE_COM }, {16, true, "\012googlecode\003com", false, kGooglePins, DOMAIN_GOOGLECODE_COM }, {15, true, "\002dl\006google\003com", true, kGooglePins, DOMAIN_GOOGLE_COM }, + {26, true, "\011translate\012googleapis\003com", true, kGooglePins, DOMAIN_GOOGLEAPIS_COM }, {23, true, "\005chart\004apis\006google\003com", false, kGooglePins, DOMAIN_GOOGLE_COM }, {11, true, "\005ytimg\003com", false, kGooglePins, DOMAIN_YTIMG_COM }, {23, true, "\021googleusercontent\003com", false, kGooglePins, DOMAIN_GOOGLEUSERCONTENT_COM }, @@ -834,6 +864,15 @@ static const struct HSTSPreload kPreloadedSTS[] = { {10, true, "\005haste\002ch", true, kNoPins, DOMAIN_NOT_PINNED }, {12, true, "\007mudcrab\002us", true, kNoPins, DOMAIN_NOT_PINNED }, {13, true, "\010mediacru\002sh", true, kNoPins, DOMAIN_NOT_PINNED }, + {13, true, "\010lolicore\002ch", true, kNoPins, DOMAIN_NOT_PINNED }, + {16, true, "\007cloudns\003com\002au", true, kNoPins, DOMAIN_NOT_PINNED }, + {19, true, "\005oplop\007appspot\003com", true, kNoPins, DOMAIN_NOT_PINNED }, + {12, false, "\006bcrook\003com", true, kNoPins, DOMAIN_NOT_PINNED }, + {17, true, "\004wiki\006python\003org", true, kNoPins, DOMAIN_NOT_PINNED }, + {9, false, "\004lumi\002do", true, kNoPins, DOMAIN_NOT_PINNED }, + {22, true, "\020appseccalifornia\003org", true, kNoPins, DOMAIN_NOT_PINNED }, + {17, true, "\013crowdcurity\003com", true, kNoPins, DOMAIN_NOT_PINNED }, + {19, true, "\013saturngames\002co\002uk", true, kNoPins, DOMAIN_NOT_PINNED }, }; static const size_t kNumPreloadedSTS = ARRAYSIZE_UNSAFE(kPreloadedSTS); diff --git a/chromium/net/http/transport_security_state_static.json b/chromium/net/http/transport_security_state_static.json index 1106187da57..67e5a86ccd8 100644 --- a/chromium/net/http/transport_security_state_static.json +++ b/chromium/net/http/transport_security_state_static.json @@ -56,7 +56,9 @@ "Aetna", "Intel", "TCTrustCenter", - "Vodafone" + "Vodafone", + "ThawteSGCCA", + "VeriSignClass3SSPIntermediateCA" ] }, { @@ -90,6 +92,9 @@ "GeoTrustPrimary", "GeoTrustPrimary_G2", "GeoTrustPrimary_G3", + "DigiCertGlobalRoot", + "DigiCertEVRoot", + "DigiCertAssuredIDRoot", "Twitter1" ] }, @@ -114,6 +119,9 @@ "GeoTrustPrimary", "GeoTrustPrimary_G2", "GeoTrustPrimary_G3", + "DigiCertGlobalRoot", + "DigiCertEVRoot", + "DigiCertAssuredIDRoot", "Twitter1", "Entrust_2048", @@ -132,7 +140,8 @@ "UTNUSERFirstClientAuthenticationandEmail", "UTNUSERFirstHardware", "UTNUSERFirstObject", - "GTECyberTrustGlobalRoot" + "GTECyberTrustGlobalRoot", + "BaltimoreCyberTrustRoot" ] }, { @@ -159,7 +168,7 @@ { "name": "google.com", "include_subdomains": true, "pins": "google" }, // Now we force HTTPS for subtrees of google.com. - { "name": "health.google.com", "include_subdomains": true, "mode": "force-https", "pins": "google" }, + { "name": "wallet.google.com", "include_subdomains": true, "mode": "force-https", "pins": "google" }, { "name": "checkout.google.com", "include_subdomains": true, "mode": "force-https", "pins": "google" }, { "name": "chrome.google.com", "include_subdomains": true, "mode": "force-https", "pins": "google" }, { "name": "docs.google.com", "include_subdomains": true, "mode": "force-https", "pins": "google" }, @@ -193,6 +202,7 @@ { "name": "code.google.com", "include_subdomains": true, "mode": "force-https", "pins": "google" }, { "name": "googlecode.com", "include_subdomains": true, "pins": "google" }, { "name": "dl.google.com", "include_subdomains": true, "mode": "force-https", "pins": "google" }, + { "name": "translate.googleapis.com", "include_subdomains": true, "mode": "force-https", "pins": "google" }, // chart.apis.google.com is *not* HSTS because the certificate doesn't match // and there are lots of links out there that still use the name. The correct @@ -624,6 +634,15 @@ { "name": "haste.ch", "include_subdomains": true, "mode": "force-https" }, { "name": "mudcrab.us", "include_subdomains": true, "mode": "force-https" }, { "name": "mediacru.sh", "include_subdomains": true, "mode": "force-https" }, + { "name": "lolicore.ch", "include_subdomains": true, "mode": "force-https" }, + { "name": "cloudns.com.au", "include_subdomains": true, "mode": "force-https" }, + { "name": "oplop.appspot.com", "include_subdomains": true, "mode": "force-https" }, + { "name": "bcrook.com", "mode": "force-https" }, + { "name": "wiki.python.org", "include_subdomains": true, "mode": "force-https" }, + { "name": "lumi.do", "mode": "force-https" }, + { "name": "appseccalifornia.org", "include_subdomains": true, "mode": "force-https" }, + { "name": "crowdcurity.com", "include_subdomains": true, "mode": "force-https" }, + { "name": "saturngames.co.uk", "include_subdomains": true, "mode": "force-https" }, // Entries that are only valid if the client supports SNI. { "name": "gmail.com", "mode": "force-https", "pins": "google", "snionly": true }, diff --git a/chromium/net/http/transport_security_state_unittest.cc b/chromium/net/http/transport_security_state_unittest.cc index 99d8c39f042..936d5628249 100644 --- a/chromium/net/http/transport_security_state_unittest.cc +++ b/chromium/net/http/transport_security_state_unittest.cc @@ -240,7 +240,7 @@ TEST_F(TransportSecurityStateTest, Preloaded) { EXPECT_TRUE(ShouldRedirect("chrome.google.com")); EXPECT_TRUE(ShouldRedirect("checkout.google.com")); - EXPECT_TRUE(ShouldRedirect("health.google.com")); + EXPECT_TRUE(ShouldRedirect("wallet.google.com")); EXPECT_TRUE(ShouldRedirect("docs.google.com")); EXPECT_TRUE(ShouldRedirect("sites.google.com")); EXPECT_TRUE(ShouldRedirect("drive.google.com")); @@ -489,7 +489,7 @@ TEST_F(TransportSecurityStateTest, BuiltinCertPins) { EXPECT_TRUE(HasPublicKeyPins("sites.google.com")); EXPECT_TRUE(HasPublicKeyPins("drive.google.com")); EXPECT_TRUE(HasPublicKeyPins("spreadsheets.google.com")); - EXPECT_TRUE(HasPublicKeyPins("health.google.com")); + EXPECT_TRUE(HasPublicKeyPins("wallet.google.com")); EXPECT_TRUE(HasPublicKeyPins("checkout.google.com")); EXPECT_TRUE(HasPublicKeyPins("appengine.google.com")); EXPECT_TRUE(HasPublicKeyPins("market.android.com")); diff --git a/chromium/net/net.gyp b/chromium/net/net.gyp index 9443da6ac8e..b28600d1618 100644 --- a/chromium/net/net.gyp +++ b/chromium/net/net.gyp @@ -360,6 +360,7 @@ 'disk_cache/in_flight_backend_io.h', 'disk_cache/in_flight_io.cc', 'disk_cache/in_flight_io.h', + 'disk_cache/mapped_file.cc', 'disk_cache/mapped_file.h', 'disk_cache/mapped_file_posix.cc', 'disk_cache/mapped_file_avoid_mmap_posix.cc', @@ -389,14 +390,18 @@ 'disk_cache/tracing_cache_backend.h', 'disk_cache/simple/simple_backend_impl.cc', 'disk_cache/simple/simple_backend_impl.h', + 'disk_cache/simple/simple_backend_version.h', 'disk_cache/simple/simple_entry_format.cc', 'disk_cache/simple/simple_entry_format.h', + 'disk_cache/simple/simple_entry_format_history.h', 'disk_cache/simple/simple_entry_impl.cc', 'disk_cache/simple/simple_entry_impl.h', 'disk_cache/simple/simple_entry_operation.cc', 'disk_cache/simple/simple_entry_operation.h', + 'disk_cache/simple/simple_histogram_macros.h' , 'disk_cache/simple/simple_index.cc', 'disk_cache/simple/simple_index.h', + 'disk_cache/simple/simple_index_delegate.h', 'disk_cache/simple/simple_index_file.cc', 'disk_cache/simple/simple_index_file.h', 'disk_cache/simple/simple_index_file_posix.cc', @@ -407,6 +412,8 @@ 'disk_cache/simple/simple_synchronous_entry.h', 'disk_cache/simple/simple_util.cc', 'disk_cache/simple/simple_util.h', + 'disk_cache/simple/simple_version_upgrade.cc', + 'disk_cache/simple/simple_version_upgrade.h', 'disk_cache/flash/flash_entry_impl.cc', 'disk_cache/flash/flash_entry_impl.h', 'disk_cache/flash/format.h', @@ -420,6 +427,8 @@ 'disk_cache/flash/segment.h', 'disk_cache/flash/storage.cc', 'disk_cache/flash/storage.h', + 'disk_cache/v3/block_bitmaps.cc', + 'disk_cache/v3/block_bitmaps.h', 'disk_cache/v3/disk_format_v3.h', 'dns/address_sorter.h', 'dns/address_sorter_posix.cc', @@ -686,7 +695,6 @@ 'proxy/proxy_server_mac.cc', 'proxy/proxy_service.cc', 'proxy/proxy_service.h', - 'quic/blocked_list.h', 'quic/congestion_control/available_channel_estimator.cc', 'quic/congestion_control/available_channel_estimator.h', 'quic/congestion_control/channel_estimator.cc', @@ -784,6 +792,8 @@ 'quic/crypto/strike_register.h', 'quic/crypto/source_address_token.cc', 'quic/crypto/source_address_token.h', + 'quic/quic_ack_notifier.cc', + 'quic/quic_ack_notifier.h', 'quic/quic_alarm.cc', 'quic/quic_alarm.h', 'quic/quic_bandwidth.cc', @@ -808,6 +818,8 @@ 'quic/quic_connection_helper.h', 'quic/quic_connection_logger.cc', 'quic/quic_connection_logger.h', + 'quic/quic_connection_stats.cc', + 'quic/quic_connection_stats.h', 'quic/quic_data_reader.cc', 'quic/quic_data_reader.h', 'quic/quic_data_writer.cc', @@ -818,6 +830,8 @@ 'quic/quic_framer.h', 'quic/quic_http_stream.cc', 'quic/quic_http_stream.h', + 'quic/quic_http_utils.cc', + 'quic/quic_http_utils.h', 'quic/quic_packet_creator.cc', 'quic/quic_packet_creator.h', 'quic/quic_packet_generator.cc', @@ -830,14 +844,14 @@ 'quic/quic_reliable_client_stream.h', 'quic/quic_sent_entropy_manager.cc', 'quic/quic_sent_entropy_manager.h', + 'quic/quic_sent_packet_manager.cc', + 'quic/quic_sent_packet_manager.h', 'quic/quic_session.cc', 'quic/quic_session.h', 'quic/quic_spdy_compressor.cc', 'quic/quic_spdy_compressor.h', 'quic/quic_spdy_decompressor.cc', 'quic/quic_spdy_decompressor.h', - 'quic/quic_stats.cc', - 'quic/quic_stats.h', 'quic/quic_stream_factory.cc', 'quic/quic_stream_factory.h', 'quic/quic_stream_sequencer.cc', @@ -870,6 +884,8 @@ 'socket/nss_ssl_util.cc', 'socket/nss_ssl_util.h', 'socket/server_socket.h', + 'socket/socket_descriptor.cc', + 'socket/socket_descriptor.h', 'socket/socket_net_log_params.cc', 'socket/socket_net_log_params.h', 'socket/socket.h', @@ -900,17 +916,16 @@ 'socket/stream_socket.h', 'socket/tcp_client_socket.cc', 'socket/tcp_client_socket.h', - 'socket/tcp_client_socket_libevent.cc', - 'socket/tcp_client_socket_libevent.h', - 'socket/tcp_client_socket_win.cc', - 'socket/tcp_client_socket_win.h', 'socket/tcp_listen_socket.cc', 'socket/tcp_listen_socket.h', + 'socket/tcp_server_socket.cc', 'socket/tcp_server_socket.h', - 'socket/tcp_server_socket_libevent.cc', - 'socket/tcp_server_socket_libevent.h', - 'socket/tcp_server_socket_win.cc', - 'socket/tcp_server_socket_win.h', + 'socket/tcp_socket.cc', + 'socket/tcp_socket.h', + 'socket/tcp_socket_libevent.cc', + 'socket/tcp_socket_libevent.h', + 'socket/tcp_socket_win.cc', + 'socket/tcp_socket_win.h', 'socket/transport_client_socket_pool.cc', 'socket/transport_client_socket_pool.h', 'socket/unix_domain_socket_posix.cc', @@ -965,6 +980,7 @@ 'spdy/spdy_websocket_stream.h', 'spdy/spdy_write_queue.cc', 'spdy/spdy_write_queue.h', + 'spdy/write_blocked_list.h', 'ssl/client_cert_store.h', 'ssl/client_cert_store_impl.h', 'ssl/client_cert_store_impl_mac.cc', @@ -1087,14 +1103,24 @@ 'url_request/url_request_throttler_manager.h', 'url_request/view_cache_helper.cc', 'url_request/view_cache_helper.h', + 'websockets/websocket_basic_stream.cc', + 'websockets/websocket_basic_stream.h', 'websockets/websocket_channel.cc', 'websockets/websocket_channel.h', + 'websockets/websocket_deflater.h', + 'websockets/websocket_deflater.cc', 'websockets/websocket_errors.cc', 'websockets/websocket_errors.h', + 'websockets/websocket_extension.cc', + 'websockets/websocket_extension.h', + 'websockets/websocket_extension_parser.cc', + 'websockets/websocket_extension_parser.h', 'websockets/websocket_frame.cc', 'websockets/websocket_frame.h', 'websockets/websocket_frame_parser.cc', 'websockets/websocket_frame_parser.h', + 'websockets/websocket_handshake_constants.cc', + 'websockets/websocket_handshake_constants.h', 'websockets/websocket_handshake_handler.cc', 'websockets/websocket_handshake_handler.h', 'websockets/websocket_job.cc', @@ -1360,10 +1386,8 @@ [ 'OS == "win"', { 'sources!': [ 'http/http_auth_handler_ntlm_portable.cc', - 'socket/tcp_client_socket_libevent.cc', - 'socket/tcp_client_socket_libevent.h', - 'socket/tcp_server_socket_libevent.cc', - 'socket/tcp_server_socket_libevent.h', + 'socket/tcp_socket_libevent.cc', + 'socket/tcp_socket_libevent.h', 'ssl/client_cert_store_impl_nss.cc', 'udp/udp_socket_libevent.cc', 'udp/udp_socket_libevent.h', @@ -1563,11 +1587,13 @@ 'disk_cache/simple/simple_test_util.h', 'disk_cache/simple/simple_test_util.cc', 'disk_cache/simple/simple_util_unittest.cc', + 'disk_cache/simple/simple_version_upgrade_unittest.cc', 'disk_cache/storage_block_unittest.cc', 'disk_cache/flash/log_store_entry_unittest.cc', 'disk_cache/flash/log_store_unittest.cc', 'disk_cache/flash/segment_unittest.cc', 'disk_cache/flash/storage_unittest.cc', + 'disk_cache/v3/block_bitmaps_unittest.cc', 'dns/address_sorter_posix_unittest.cc', 'dns/address_sorter_unittest.cc', 'dns/dns_config_service_posix_unittest.cc', @@ -1633,6 +1659,7 @@ 'http/http_security_headers_unittest.cc', 'http/http_server_properties_impl_unittest.cc', 'http/http_status_code_unittest.cc', + 'http/http_stream_factory_impl_request_unittest.cc', 'http/http_stream_factory_impl_unittest.cc', 'http/http_stream_parser_unittest.cc', 'http/http_transaction_unittest.cc', @@ -1668,7 +1695,6 @@ 'proxy/proxy_script_fetcher_impl_unittest.cc', 'proxy/proxy_server_unittest.cc', 'proxy/proxy_service_unittest.cc', - 'quic/blocked_list_test.cc', 'quic/congestion_control/available_channel_estimator_test.cc', 'quic/congestion_control/channel_estimator_test.cc', 'quic/congestion_control/cube_root_test.cc', @@ -1738,6 +1764,7 @@ 'quic/test_tools/simple_quic_framer.h', 'quic/test_tools/test_task_runner.cc', 'quic/test_tools/test_task_runner.h', + 'quic/quic_ack_notifier_test.cc', 'quic/quic_alarm_test.cc', 'quic/quic_bandwidth_test.cc', 'quic/quic_client_session_test.cc', @@ -1752,6 +1779,7 @@ 'quic/quic_fec_group_test.cc', 'quic/quic_framer_test.cc', 'quic/quic_http_stream_test.cc', + 'quic/quic_http_utils_test.cc', 'quic/quic_network_transaction_unittest.cc', 'quic/quic_packet_creator_test.cc', 'quic/quic_packet_generator_test.cc', @@ -1759,6 +1787,7 @@ 'quic/quic_received_packet_manager_test.cc', 'quic/quic_reliable_client_stream_test.cc', 'quic/quic_sent_entropy_manager_test.cc', + 'quic/quic_sent_packet_manager_test.cc', 'quic/quic_session_test.cc', 'quic/quic_spdy_compressor_test.cc', 'quic/quic_spdy_decompressor_test.cc', @@ -1785,6 +1814,7 @@ 'socket/tcp_listen_socket_unittest.cc', 'socket/tcp_listen_socket_unittest.h', 'socket/tcp_server_socket_unittest.cc', + 'socket/tcp_socket_unittest.cc', 'socket/transport_client_socket_pool_unittest.cc', 'socket/transport_client_socket_unittest.cc', 'socket/unix_domain_socket_posix_unittest.cc', @@ -1820,6 +1850,7 @@ 'spdy/spdy_websocket_test_util.cc', 'spdy/spdy_websocket_test_util.h', 'spdy/spdy_write_queue_unittest.cc', + 'spdy/write_blocked_list_test.cc', 'ssl/client_cert_store_impl_unittest.cc', 'ssl/default_server_bound_cert_store_unittest.cc', 'ssl/openssl_client_key_store_unittest.cc', @@ -1854,15 +1885,18 @@ 'url_request/url_request_throttler_unittest.cc', 'url_request/url_request_unittest.cc', 'url_request/view_cache_helper_unittest.cc', + 'websockets/websocket_basic_stream_test.cc', 'websockets/websocket_channel_test.cc', - 'websockets/websocket_errors_unittest.cc', - 'websockets/websocket_frame_parser_unittest.cc', - 'websockets/websocket_frame_unittest.cc', - 'websockets/websocket_handshake_handler_unittest.cc', - 'websockets/websocket_handshake_handler_spdy_unittest.cc', - 'websockets/websocket_job_unittest.cc', - 'websockets/websocket_net_log_params_unittest.cc', - 'websockets/websocket_throttle_unittest.cc', + 'websockets/websocket_deflater_test.cc', + 'websockets/websocket_errors_test.cc', + 'websockets/websocket_extension_parser_test.cc', + 'websockets/websocket_frame_parser_test.cc', + 'websockets/websocket_frame_test.cc', + 'websockets/websocket_handshake_handler_test.cc', + 'websockets/websocket_handshake_handler_spdy_test.cc', + 'websockets/websocket_job_test.cc', + 'websockets/websocket_net_log_params_test.cc', + 'websockets/websocket_throttle_test.cc', ], 'conditions': [ ['os_posix == 1 and OS != "mac" and OS != "ios" and OS != "android"', { @@ -1874,9 +1908,13 @@ 'sources': [ 'tools/flip_server/balsa_frame_test.cc', 'tools/flip_server/balsa_headers_test.cc', + 'tools/flip_server/flip_test_utils.cc', + 'tools/flip_server/flip_test_utils.h', + 'tools/flip_server/http_interface_test.cc', 'tools/flip_server/mem_cache_test.cc', 'tools/flip_server/simple_buffer.cc', 'tools/flip_server/simple_buffer.h', + 'tools/flip_server/spdy_interface_test.cc', 'tools/quic/end_to_end_test.cc', 'tools/quic/quic_client_session_test.cc', 'tools/quic/quic_dispatcher_test.cc', @@ -1885,8 +1923,10 @@ 'tools/quic/quic_in_memory_cache_test.cc', 'tools/quic/quic_reliable_client_stream_test.cc', 'tools/quic/quic_reliable_server_stream_test.cc', + 'tools/quic/quic_server_session_test.cc', 'tools/quic/quic_server_test.cc', 'tools/quic/quic_spdy_server_stream_test.cc', + 'tools/quic/quic_time_wait_list_manager_test.cc', 'tools/quic/test_tools/http_message_test_utils.cc', 'tools/quic/test_tools/http_message_test_utils.h', 'tools/quic/test_tools/mock_epoll_server.cc', @@ -1982,6 +2022,9 @@ 'socket/ssl_client_socket_openssl_unittest.cc', 'ssl/openssl_client_key_store_unittest.cc', ], + 'sources/': [ + ['exclude', '^tools/flip_server'], + ], }, ], [ 'enable_websockets != 1', { @@ -2135,11 +2178,6 @@ '../testing/android/native_test.gyp:native_test_native_code', ] }], - [ 'OS != "win" and OS != "mac"', { - 'sources!': [ - 'cert/x509_cert_types_unittest.cc', - ], - }], ], }, { @@ -2911,6 +2949,21 @@ }, ], }], + ['OS == "android" or OS == "linux"', { + 'targets': [ + { + 'target_name': 'disk_cache_memory_test', + 'type': 'executable', + 'dependencies': [ + '../base/base.gyp:base', + 'net', + ], + 'sources': [ + 'tools/disk_cache_memory_test/disk_cache_memory_test.cc', + ], + }, + ], + }], ['test_isolation_mode != "noop"', { 'targets': [ { diff --git a/chromium/net/net_unittests.isolate b/chromium/net/net_unittests.isolate index 71a2dc05db3..9c88f77493d 100644 --- a/chromium/net/net_unittests.isolate +++ b/chromium/net/net_unittests.isolate @@ -34,14 +34,13 @@ 'variables': { 'isolate_dependency_tracked': [ '../testing/test_env.py', - '../tools/swarm_client/run_isolated.py', - '../tools/swarm_client/googletest/run_test_cases.py', '<(PRODUCT_DIR)/net_unittests<(EXECUTABLE_SUFFIX)', ], 'isolate_dependency_untracked': [ '../third_party/pyftpdlib/', '../third_party/pywebsocket/', '../third_party/tlslite/', + '../tools/swarm_client/', '<(PRODUCT_DIR)/pyproto/', 'tools/testserver/', ], diff --git a/chromium/net/ocsp/nss_ocsp_unittest.cc b/chromium/net/ocsp/nss_ocsp_unittest.cc index be29d5f7931..0530282f4a6 100644 --- a/chromium/net/ocsp/nss_ocsp_unittest.cc +++ b/chromium/net/ocsp/nss_ocsp_unittest.cc @@ -78,7 +78,7 @@ class NssHttpTest : public ::testing::Test { virtual void SetUp() { std::string file_contents; - ASSERT_TRUE(file_util::ReadFileToString( + ASSERT_TRUE(base::ReadFileToString( GetTestCertsDirectory().AppendASCII("aia-intermediate.der"), &file_contents)); ASSERT_FALSE(file_contents.empty()); diff --git a/chromium/net/proxy/dhcp_proxy_script_adapter_fetcher_win.cc b/chromium/net/proxy/dhcp_proxy_script_adapter_fetcher_win.cc index 56e47471bf0..676f6c3a4d6 100644 --- a/chromium/net/proxy/dhcp_proxy_script_adapter_fetcher_win.cc +++ b/chromium/net/proxy/dhcp_proxy_script_adapter_fetcher_win.cc @@ -10,7 +10,7 @@ #include "base/message_loop/message_loop_proxy.h" #include "base/metrics/histogram.h" #include "base/strings/sys_string_conversions.h" -#include "base/threading/worker_pool.h" +#include "base/task_runner.h" #include "base/time/time.h" #include "net/base/net_errors.h" #include "net/proxy/dhcpcsvc_init_win.h" @@ -32,8 +32,10 @@ const int kTimeoutMs = 2000; namespace net { DhcpProxyScriptAdapterFetcher::DhcpProxyScriptAdapterFetcher( - URLRequestContext* url_request_context) - : state_(STATE_START), + URLRequestContext* url_request_context, + scoped_refptr<base::TaskRunner> task_runner) + : task_runner_(task_runner), + state_(STATE_START), result_(ERR_IO_PENDING), url_request_context_(url_request_context) { DCHECK(url_request_context_); @@ -55,7 +57,7 @@ void DhcpProxyScriptAdapterFetcher::Fetch( wait_timer_.Start(FROM_HERE, ImplGetTimeout(), this, &DhcpProxyScriptAdapterFetcher::OnTimeout); scoped_refptr<DhcpQuery> dhcp_query(ImplCreateDhcpQuery()); - base::WorkerPool::PostTaskAndReply( + task_runner_->PostTaskAndReply( FROM_HERE, base::Bind( &DhcpProxyScriptAdapterFetcher::DhcpQuery::GetPacURLForAdapter, @@ -64,8 +66,7 @@ void DhcpProxyScriptAdapterFetcher::Fetch( base::Bind( &DhcpProxyScriptAdapterFetcher::OnDhcpQueryDone, AsWeakPtr(), - dhcp_query), - true); + dhcp_query)); } void DhcpProxyScriptAdapterFetcher::Cancel() { diff --git a/chromium/net/proxy/dhcp_proxy_script_adapter_fetcher_win.h b/chromium/net/proxy/dhcp_proxy_script_adapter_fetcher_win.h index fadf2344656..59597d9a484 100644 --- a/chromium/net/proxy/dhcp_proxy_script_adapter_fetcher_win.h +++ b/chromium/net/proxy/dhcp_proxy_script_adapter_fetcher_win.h @@ -17,6 +17,10 @@ #include "net/base/net_export.h" #include "url/gurl.h" +namespace base { +class TaskRunner; +} + namespace net { class ProxyScriptFetcher; @@ -29,8 +33,9 @@ class NET_EXPORT_PRIVATE DhcpProxyScriptAdapterFetcher NON_EXPORTED_BASE(public base::NonThreadSafe) { public: // |url_request_context| must outlive DhcpProxyScriptAdapterFetcher. - explicit DhcpProxyScriptAdapterFetcher( - URLRequestContext* url_request_context); + // |task_runner| will be used to post tasks to a thread. + DhcpProxyScriptAdapterFetcher(URLRequestContext* url_request_context, + scoped_refptr<base::TaskRunner> task_runner); virtual ~DhcpProxyScriptAdapterFetcher(); // Starts a fetch. On completion (but not cancellation), |callback| @@ -144,6 +149,9 @@ class NET_EXPORT_PRIVATE DhcpProxyScriptAdapterFetcher void OnFetcherDone(int result); void TransitionToFinish(); + // TaskRunner for posting tasks to a worker thread. + scoped_refptr<base::TaskRunner> task_runner_; + // Current state of this state machine. State state_; diff --git a/chromium/net/proxy/dhcp_proxy_script_adapter_fetcher_win_unittest.cc b/chromium/net/proxy/dhcp_proxy_script_adapter_fetcher_win_unittest.cc index be177fa8ab8..17285122e8e 100644 --- a/chromium/net/proxy/dhcp_proxy_script_adapter_fetcher_win_unittest.cc +++ b/chromium/net/proxy/dhcp_proxy_script_adapter_fetcher_win_unittest.cc @@ -4,9 +4,10 @@ #include "net/proxy/dhcp_proxy_script_adapter_fetcher_win.h" -#include "base/perftimer.h" #include "base/synchronization/waitable_event.h" +#include "base/test/perftimer.h" #include "base/test/test_timeouts.h" +#include "base/threading/sequenced_worker_pool.h" #include "base/timer/timer.h" #include "net/base/net_errors.h" #include "net/base/test_completion_callback.h" @@ -33,8 +34,10 @@ const char* const kPacUrl = "http://pacserver/script.pac"; class MockDhcpProxyScriptAdapterFetcher : public DhcpProxyScriptAdapterFetcher { public: - explicit MockDhcpProxyScriptAdapterFetcher(URLRequestContext* context) - : DhcpProxyScriptAdapterFetcher(context), + explicit MockDhcpProxyScriptAdapterFetcher( + URLRequestContext* context, + scoped_refptr<base::TaskRunner> task_runner) + : DhcpProxyScriptAdapterFetcher(context, task_runner), dhcp_delay_(base::TimeDelta::FromMilliseconds(1)), timeout_(TestTimeouts::action_timeout()), configured_url_(kPacUrl), @@ -132,8 +135,16 @@ class FetcherClient { public: FetcherClient() : url_request_context_(new TestURLRequestContext()), - fetcher_( - new MockDhcpProxyScriptAdapterFetcher(url_request_context_.get())) { + worker_pool_( + new base::SequencedWorkerPool(4, "DhcpAdapterFetcherTest")), + fetcher_(new MockDhcpProxyScriptAdapterFetcher( + url_request_context_.get(), + worker_pool_->GetTaskRunnerWithShutdownBehavior( + base::SequencedWorkerPool::CONTINUE_ON_SHUTDOWN))) { + } + + ~FetcherClient() { + worker_pool_->Shutdown(); } void WaitForResult(int expected_error) { @@ -151,6 +162,7 @@ class FetcherClient { TestCompletionCallback callback_; scoped_ptr<URLRequestContext> url_request_context_; + scoped_refptr<base::SequencedWorkerPool> worker_pool_; scoped_ptr<MockDhcpProxyScriptAdapterFetcher> fetcher_; base::string16 pac_text_; }; @@ -253,8 +265,9 @@ class MockDhcpRealFetchProxyScriptAdapterFetcher : public MockDhcpProxyScriptAdapterFetcher { public: explicit MockDhcpRealFetchProxyScriptAdapterFetcher( - URLRequestContext* context) - : MockDhcpProxyScriptAdapterFetcher(context), + URLRequestContext* context, + scoped_refptr<base::TaskRunner> task_runner) + : MockDhcpProxyScriptAdapterFetcher(context, task_runner), url_request_context_(context) { } @@ -280,9 +293,12 @@ TEST(DhcpProxyScriptAdapterFetcher, MockDhcpRealFetch) { FetcherClient client; TestURLRequestContext url_request_context; + scoped_refptr<base::TaskRunner> runner = + client.worker_pool_->GetTaskRunnerWithShutdownBehavior( + base::SequencedWorkerPool::CONTINUE_ON_SHUTDOWN); client.fetcher_.reset( new MockDhcpRealFetchProxyScriptAdapterFetcher( - &url_request_context)); + &url_request_context, runner)); client.fetcher_->configured_url_ = configured_url.spec(); client.RunTest(); client.WaitForResult(OK); diff --git a/chromium/net/proxy/dhcp_proxy_script_fetcher_win.cc b/chromium/net/proxy/dhcp_proxy_script_fetcher_win.cc index 9e34f5122e4..ac28e0fc075 100644 --- a/chromium/net/proxy/dhcp_proxy_script_fetcher_win.cc +++ b/chromium/net/proxy/dhcp_proxy_script_fetcher_win.cc @@ -7,8 +7,8 @@ #include "base/bind.h" #include "base/bind_helpers.h" #include "base/metrics/histogram.h" -#include "base/perftimer.h" -#include "base/threading/worker_pool.h" +#include "base/test/perftimer.h" +#include "base/threading/sequenced_worker_pool.h" #include "net/base/net_errors.h" #include "net/proxy/dhcp_proxy_script_adapter_fetcher_win.h" @@ -18,6 +18,27 @@ namespace { +// How many threads to use at maximum to do DHCP lookups. This is +// chosen based on the following UMA data: +// - When OnWaitTimer fires, ~99.8% of users have 6 or fewer network +// adapters enabled for DHCP in total. +// - At the same measurement point, ~99.7% of users have 3 or fewer pending +// DHCP adapter lookups. +// - There is however a very long and thin tail of users who have +// systems reporting up to 100+ adapters (this must be some very weird +// OS bug (?), probably the cause of http://crbug.com/240034). +// +// The maximum number of threads is chosen such that even systems that +// report a huge number of network adapters should not run out of +// memory from this number of threads, while giving a good chance of +// getting back results for any responsive adapters. +// +// The ~99.8% of systems that have 6 or fewer network adapters will +// not grow the thread pool to its maximum size (rather, they will +// grow it to 6 or fewer threads) so setting the limit lower would not +// improve performance or memory usage on those systems. +const int kMaxDhcpLookupThreads = 12; + // How long to wait at maximum after we get results (a PAC file or // knowledge that no PAC file is configured) from whichever network // adapter finishes first. @@ -42,11 +63,16 @@ DhcpProxyScriptFetcherWin::DhcpProxyScriptFetcherWin( destination_string_(NULL), url_request_context_(url_request_context) { DCHECK(url_request_context_); + + worker_pool_ = new base::SequencedWorkerPool(kMaxDhcpLookupThreads, + "PacDhcpLookup"); } DhcpProxyScriptFetcherWin::~DhcpProxyScriptFetcherWin() { // Count as user-initiated if we are not yet in STATE_DONE. Cancel(); + + worker_pool_->Shutdown(); } int DhcpProxyScriptFetcherWin::Fetch(base::string16* utf16_text, @@ -64,7 +90,7 @@ int DhcpProxyScriptFetcherWin::Fetch(base::string16* utf16_text, destination_string_ = utf16_text; last_query_ = ImplCreateAdapterQuery(); - base::WorkerPool::PostTaskAndReply( + GetTaskRunner()->PostTaskAndReply( FROM_HERE, base::Bind( &DhcpProxyScriptFetcherWin::AdapterQuery::GetCandidateAdapterNames, @@ -72,8 +98,7 @@ int DhcpProxyScriptFetcherWin::Fetch(base::string16* utf16_text, base::Bind( &DhcpProxyScriptFetcherWin::OnGetCandidateAdapterNamesDone, AsWeakPtr(), - last_query_), - true); + last_query_)); return ERR_IO_PENDING; } @@ -267,9 +292,15 @@ URLRequestContext* DhcpProxyScriptFetcherWin::url_request_context() const { return url_request_context_; } +scoped_refptr<base::TaskRunner> DhcpProxyScriptFetcherWin::GetTaskRunner() { + return worker_pool_->GetTaskRunnerWithShutdownBehavior( + base::SequencedWorkerPool::CONTINUE_ON_SHUTDOWN); +} + DhcpProxyScriptAdapterFetcher* DhcpProxyScriptFetcherWin::ImplCreateAdapterFetcher() { - return new DhcpProxyScriptAdapterFetcher(url_request_context_); + return new DhcpProxyScriptAdapterFetcher(url_request_context_, + GetTaskRunner()); } DhcpProxyScriptFetcherWin::AdapterQuery* diff --git a/chromium/net/proxy/dhcp_proxy_script_fetcher_win.h b/chromium/net/proxy/dhcp_proxy_script_fetcher_win.h index 79fc4b348ef..d6f14f90ba0 100644 --- a/chromium/net/proxy/dhcp_proxy_script_fetcher_win.h +++ b/chromium/net/proxy/dhcp_proxy_script_fetcher_win.h @@ -16,6 +16,10 @@ #include "base/timer/timer.h" #include "net/proxy/dhcp_proxy_script_fetcher.h" +namespace base { +class SequencedWorkerPool; +} + namespace net { class DhcpProxyScriptAdapterFetcher; @@ -50,6 +54,8 @@ class NET_EXPORT_PRIVATE DhcpProxyScriptFetcherWin URLRequestContext* url_request_context() const; + scoped_refptr<base::TaskRunner> GetTaskRunner(); + // This inner class encapsulate work done on a worker pool thread. // The class calls GetCandidateAdapterNames, which can take a couple of // hundred milliseconds. @@ -161,6 +167,9 @@ class NET_EXPORT_PRIVATE DhcpProxyScriptFetcherWin // Time |Fetch()| was last called, 0 if never. base::TimeTicks fetch_start_time_; + // Worker pool we use for all DHCP lookup tasks. + scoped_refptr<base::SequencedWorkerPool> worker_pool_; + DISALLOW_IMPLICIT_CONSTRUCTORS(DhcpProxyScriptFetcherWin); }; diff --git a/chromium/net/proxy/dhcp_proxy_script_fetcher_win_unittest.cc b/chromium/net/proxy/dhcp_proxy_script_fetcher_win_unittest.cc index 919787a435a..cf0cee05a14 100644 --- a/chromium/net/proxy/dhcp_proxy_script_fetcher_win_unittest.cc +++ b/chromium/net/proxy/dhcp_proxy_script_fetcher_win_unittest.cc @@ -9,8 +9,8 @@ #include "base/bind.h" #include "base/bind_helpers.h" #include "base/message_loop/message_loop.h" -#include "base/perftimer.h" #include "base/rand_util.h" +#include "base/test/perftimer.h" #include "base/test/test_timeouts.h" #include "base/threading/platform_thread.h" #include "net/base/completion_callback.h" @@ -156,9 +156,10 @@ TEST(DhcpProxyScriptFetcherWin, RealFetchWithCancel) { class DelayingDhcpProxyScriptAdapterFetcher : public DhcpProxyScriptAdapterFetcher { public: - explicit DelayingDhcpProxyScriptAdapterFetcher( - URLRequestContext* url_request_context) - : DhcpProxyScriptAdapterFetcher(url_request_context) { + DelayingDhcpProxyScriptAdapterFetcher( + URLRequestContext* url_request_context, + scoped_refptr<base::TaskRunner> task_runner) + : DhcpProxyScriptAdapterFetcher(url_request_context, task_runner) { } class DelayingDhcpQuery : public DhcpQuery { @@ -189,7 +190,8 @@ class DelayingDhcpProxyScriptFetcherWin } DhcpProxyScriptAdapterFetcher* ImplCreateAdapterFetcher() OVERRIDE { - return new DelayingDhcpProxyScriptAdapterFetcher(url_request_context()); + return new DelayingDhcpProxyScriptAdapterFetcher(url_request_context(), + GetTaskRunner()); } }; @@ -212,8 +214,9 @@ TEST(DhcpProxyScriptFetcherWin, RealFetchWithDeferredCancel) { class DummyDhcpProxyScriptAdapterFetcher : public DhcpProxyScriptAdapterFetcher { public: - explicit DummyDhcpProxyScriptAdapterFetcher(URLRequestContext* context) - : DhcpProxyScriptAdapterFetcher(context), + DummyDhcpProxyScriptAdapterFetcher(URLRequestContext* context, + scoped_refptr<base::TaskRunner> runner) + : DhcpProxyScriptAdapterFetcher(context, runner), did_finish_(false), result_(OK), pac_script_(L"bingo"), @@ -297,6 +300,8 @@ class MockDhcpProxyScriptFetcherWin : public DhcpProxyScriptFetcherWin { ResetTestState(); } + using DhcpProxyScriptFetcherWin::GetTaskRunner; + // Adds a fetcher object to the queue of fetchers used by // |ImplCreateAdapterFetcher()|, and its name to the list of adapters // returned by ImplGetCandidateAdapterNames. @@ -312,7 +317,8 @@ class MockDhcpProxyScriptFetcherWin : public DhcpProxyScriptFetcherWin { base::string16 pac_script, base::TimeDelta fetch_delay) { scoped_ptr<DummyDhcpProxyScriptAdapterFetcher> adapter_fetcher( - new DummyDhcpProxyScriptAdapterFetcher(url_request_context())); + new DummyDhcpProxyScriptAdapterFetcher(url_request_context(), + GetTaskRunner())); adapter_fetcher->Configure( did_finish, result, pac_script, fetch_delay.InMilliseconds()); PushBackAdapter(adapter_name, adapter_fetcher.release()); @@ -372,7 +378,7 @@ class MockDhcpProxyScriptFetcherWin : public DhcpProxyScriptFetcherWin { }; class FetcherClient { -public: + public: FetcherClient() : context_(new TestURLRequestContext), fetcher_(context_.get()), @@ -414,6 +420,10 @@ public: fetcher_.ResetTestState(); } + scoped_refptr<base::TaskRunner> GetTaskRunner() { + return fetcher_.GetTaskRunner(); + } + scoped_ptr<URLRequestContext> context_; MockDhcpProxyScriptFetcherWin fetcher_; bool finished_; @@ -426,7 +436,8 @@ public: void TestNormalCaseURLConfiguredOneAdapter(FetcherClient* client) { TestURLRequestContext context; scoped_ptr<DummyDhcpProxyScriptAdapterFetcher> adapter_fetcher( - new DummyDhcpProxyScriptAdapterFetcher(&context)); + new DummyDhcpProxyScriptAdapterFetcher(&context, + client->GetTaskRunner())); adapter_fetcher->Configure(true, OK, L"bingo", 1); client->fetcher_.PushBackAdapter("a", adapter_fetcher.release()); client->RunTest(); @@ -586,7 +597,8 @@ TEST(DhcpProxyScriptFetcherWin, ShortCircuitLessPreferredAdapters) { void TestImmediateCancel(FetcherClient* client) { TestURLRequestContext context; scoped_ptr<DummyDhcpProxyScriptAdapterFetcher> adapter_fetcher( - new DummyDhcpProxyScriptAdapterFetcher(&context)); + new DummyDhcpProxyScriptAdapterFetcher(&context, + client->GetTaskRunner())); adapter_fetcher->Configure(true, OK, L"bingo", 1); client->fetcher_.PushBackAdapter("a", adapter_fetcher.release()); client->RunTest(); diff --git a/chromium/net/proxy/proxy_resolver_perftest.cc b/chromium/net/proxy/proxy_resolver_perftest.cc index 3faf3961b81..12ffd1bd91a 100644 --- a/chromium/net/proxy/proxy_resolver_perftest.cc +++ b/chromium/net/proxy/proxy_resolver_perftest.cc @@ -6,8 +6,8 @@ #include "base/compiler_specific.h" #include "base/file_util.h" #include "base/path_service.h" -#include "base/perftimer.h" #include "base/strings/string_util.h" +#include "base/test/perf_time_logger.h" #include "net/base/net_errors.h" #include "net/dns/mock_host_resolver.h" #include "net/proxy/proxy_info.h" @@ -130,7 +130,7 @@ class PacPerfSuiteRunner { // Start the perf timer. std::string perf_test_name = resolver_name_ + "_" + script_name; - PerfTimeLogger timer(perf_test_name.c_str()); + base::PerfTimeLogger timer(perf_test_name.c_str()); for (int i = 0; i < kNumIterations; ++i) { // Round-robin between URLs to resolve. @@ -163,7 +163,7 @@ class PacPerfSuiteRunner { // Try to read the file from disk. std::string file_contents; - bool ok = file_util::ReadFileToString(path, &file_contents); + bool ok = base::ReadFileToString(path, &file_contents); // If we can't load the file from disk, something is misconfigured. LOG_IF(ERROR, !ok) << "Failed to read file: " << path.value(); diff --git a/chromium/net/proxy/proxy_resolver_v8.cc b/chromium/net/proxy/proxy_resolver_v8.cc index 87f61028039..190a4c0c542 100644 --- a/chromium/net/proxy/proxy_resolver_v8.cc +++ b/chromium/net/proxy/proxy_resolver_v8.cc @@ -184,11 +184,12 @@ v8::Local<v8::String> ASCIILiteralToV8String(const char* ascii) { // Stringizes a V8 object by calling its toString() method. Returns true // on success. This may fail if the toString() throws an exception. bool V8ObjectToUTF16String(v8::Handle<v8::Value> object, - base::string16* utf16_result) { + base::string16* utf16_result, + v8::Isolate* isolate) { if (object.IsEmpty()) return false; - v8::HandleScope scope; + v8::HandleScope scope(isolate); v8::Local<v8::String> str_object = object->ToString(); if (str_object.IsEmpty()) return false; @@ -342,8 +343,8 @@ class ProxyResolverV8::Context { ~Context() { v8::Locker locked(isolate_); - v8_this_.Dispose(isolate_); - v8_context_.Dispose(isolate_); + v8_this_.Dispose(); + v8_context_.Dispose(); } JSBindings* js_bindings() { @@ -507,7 +508,7 @@ class ProxyResolverV8::Context { if (!message.IsEmpty()) { line_number = message->GetLineNumber(); - V8ObjectToUTF16String(message->Get(), &error_message); + V8ObjectToUTF16String(message->Get(), &error_message, isolate_); } js_bindings()->OnError(line_number, error_message); @@ -547,7 +548,7 @@ class ProxyResolverV8::Context { if (args.Length() == 0) { message = ASCIIToUTF16("undefined"); } else { - if (!V8ObjectToUTF16String(args[0], &message)) + if (!V8ObjectToUTF16String(args[0], &message, args.GetIsolate())) return; // toString() threw an exception. } diff --git a/chromium/net/proxy/proxy_resolver_v8_tracing.cc b/chromium/net/proxy/proxy_resolver_v8_tracing.cc index 4f6f5fc17d9..dfea44a3044 100644 --- a/chromium/net/proxy/proxy_resolver_v8_tracing.cc +++ b/chromium/net/proxy/proxy_resolver_v8_tracing.cc @@ -849,6 +849,7 @@ void ProxyResolverV8Tracing::Job::DoDnsOperation() { HostResolver::RequestHandle dns_request = NULL; int result = host_resolver()->Resolve( MakeDnsRequestInfo(pending_dns_host_, pending_dns_op_), + DEFAULT_PRIORITY, &pending_dns_addresses_, base::Bind(&Job::OnDnsOperationComplete, this), &dns_request, diff --git a/chromium/net/proxy/proxy_resolver_v8_tracing_unittest.cc b/chromium/net/proxy/proxy_resolver_v8_tracing_unittest.cc index e597402c89c..805a8734365 100644 --- a/chromium/net/proxy/proxy_resolver_v8_tracing_unittest.cc +++ b/chromium/net/proxy/proxy_resolver_v8_tracing_unittest.cc @@ -49,7 +49,7 @@ scoped_refptr<ProxyResolverScriptData> LoadScriptData(const char* filename) { // Try to read the file from disk. std::string file_contents; - bool ok = file_util::ReadFileToString(path, &file_contents); + bool ok = base::ReadFileToString(path, &file_contents); // If we can't load the file from disk, something is misconfigured. EXPECT_TRUE(ok) << "Failed to read file: " << path.value(); @@ -764,6 +764,7 @@ class BlockableHostResolver : public HostResolver { : num_cancelled_requests_(0), waiting_for_resolve_(false) {} virtual int Resolve(const RequestInfo& info, + RequestPriority priority, AddressList* addresses, const CompletionCallback& callback, RequestHandle* out_req, diff --git a/chromium/net/proxy/proxy_resolver_v8_unittest.cc b/chromium/net/proxy/proxy_resolver_v8_unittest.cc index 67e77397a8c..cbbedcec618 100644 --- a/chromium/net/proxy/proxy_resolver_v8_unittest.cc +++ b/chromium/net/proxy/proxy_resolver_v8_unittest.cc @@ -117,7 +117,7 @@ class ProxyResolverV8WithMockBindings : public ProxyResolverV8 { // Try to read the file from disk. std::string file_contents; - bool ok = file_util::ReadFileToString(path, &file_contents); + bool ok = base::ReadFileToString(path, &file_contents); // If we can't load the file from disk, something is misconfigured. if (!ok) { diff --git a/chromium/net/proxy/proxy_script_decider.cc b/chromium/net/proxy/proxy_script_decider.cc index 38bf751cd4d..3db33fcc89a 100644 --- a/chromium/net/proxy/proxy_script_decider.cc +++ b/chromium/net/proxy/proxy_script_decider.cc @@ -9,6 +9,7 @@ #include "base/compiler_specific.h" #include "base/format_macros.h" #include "base/logging.h" +#include "base/metrics/histogram.h" #include "base/strings/string_util.h" #include "base/strings/utf_string_conversions.h" #include "base/values.h" @@ -16,6 +17,7 @@ #include "net/proxy/dhcp_proxy_script_fetcher.h" #include "net/proxy/dhcp_proxy_script_fetcher_factory.h" #include "net/proxy/proxy_script_fetcher.h" +#include "net/url_request/url_request_context.h" namespace net { @@ -45,7 +47,10 @@ bool LooksLikePacScript(const base::string16& script) { // // For more details, also check out this comment: // http://code.google.com/p/chromium/issues/detail?id=18575#c20 -static const char kWpadUrl[] = "http://wpad/wpad.dat"; +namespace { +const char kWpadUrl[] = "http://wpad/wpad.dat"; +const int kQuickCheckDelayMs = 1000; +}; base::Value* ProxyScriptDecider::PacSource::NetLogCallback( const GURL* effective_pac_url, @@ -82,6 +87,12 @@ ProxyScriptDecider::ProxyScriptDecider( net_log_(BoundNetLog::Make( net_log, NetLog::SOURCE_PROXY_SCRIPT_DECIDER)), fetch_pac_bytes_(false) { + if (proxy_script_fetcher && + proxy_script_fetcher->GetRequestContext() && + proxy_script_fetcher->GetRequestContext()->host_resolver()) { + host_resolver_.reset(new SingleRequestHostResolver( + proxy_script_fetcher->GetRequestContext()->host_resolver())); + } } ProxyScriptDecider::~ProxyScriptDecider() { @@ -106,6 +117,7 @@ int ProxyScriptDecider::Start( wait_delay_ = base::TimeDelta(); pac_mandatory_ = config.pac_mandatory(); + have_custom_pac_url_ = config.has_pac_url(); pac_sources_ = BuildPacSourcesFallbackList(config); DCHECK(!pac_sources_.empty()); @@ -172,6 +184,13 @@ int ProxyScriptDecider::DoLoop(int result) { case STATE_WAIT_COMPLETE: rv = DoWaitComplete(rv); break; + case STATE_QUICK_CHECK: + DCHECK_EQ(OK, rv); + rv = DoQuickCheck(); + break; + case STATE_QUICK_CHECK_COMPLETE: + rv = DoQuickCheckComplete(rv); + break; case STATE_FETCH_PAC_SCRIPT: DCHECK_EQ(OK, rv); rv = DoFetchPacScript(); @@ -225,6 +244,60 @@ int ProxyScriptDecider::DoWaitComplete(int result) { return OK; } +int ProxyScriptDecider::DoQuickCheck() { + if (host_resolver_.get() == NULL) { + // If we have no resolver, skip QuickCheck altogether. + next_state_ = GetStartState(); + return OK; + } + + if (have_custom_pac_url_) { + // If there's a custom URL, skip QuickCheck. + next_state_ = GetStartState(); + return OK; + } + + quick_check_start_time_ = base::Time::Now(); + HostResolver::RequestInfo reqinfo(HostPortPair("wpad", 80)); + reqinfo.set_host_resolver_flags(HOST_RESOLVER_SYSTEM_ONLY); + CompletionCallback callback = base::Bind( + &ProxyScriptDecider::OnIOCompletion, + base::Unretained(this)); + + + // We use HIGHEST here because proxy decision blocks doing any other requests. + int rv = host_resolver_->Resolve(reqinfo, HIGHEST, &wpad_addresses_, + callback, net_log_); + + // We can't get an error response - the name is known to be valid, and we + // don't cache negative dns responses. + DCHECK(rv == OK || rv == ERR_IO_PENDING); + + if (rv == OK) { + next_state_ = GetStartState(); + } else { + quick_check_timer_.Start(FROM_HERE, + base::TimeDelta::FromMilliseconds( + kQuickCheckDelayMs), + base::Bind(callback, ERR_NAME_NOT_RESOLVED)); + next_state_ = STATE_QUICK_CHECK_COMPLETE; + } + return rv; +} + +int ProxyScriptDecider::DoQuickCheckComplete(int result) { + base::TimeDelta delta = base::Time::Now() - quick_check_start_time_; + if (result == OK) + UMA_HISTOGRAM_TIMES("Net.WpadQuickCheckSuccess", delta); + else + UMA_HISTOGRAM_TIMES("Net.WpadQuickCheckFailure", delta); + host_resolver_->Cancel(); + quick_check_timer_.Stop(); + if (result == OK) + next_state_ = GetStartState(); + return result; +} + int ProxyScriptDecider::DoFetchPacScript() { DCHECK(fetch_pac_bytes_); diff --git a/chromium/net/proxy/proxy_script_decider.h b/chromium/net/proxy/proxy_script_decider.h index 9a77938ec8e..23fa7af85c2 100644 --- a/chromium/net/proxy/proxy_script_decider.h +++ b/chromium/net/proxy/proxy_script_decider.h @@ -12,9 +12,12 @@ #include "base/strings/string16.h" #include "base/time/time.h" #include "base/timer/timer.h" +#include "net/base/address_list.h" #include "net/base/completion_callback.h" #include "net/base/net_export.h" #include "net/base/net_log.h" +#include "net/dns/host_resolver.h" +#include "net/dns/single_request_host_resolver.h" #include "net/proxy/proxy_config.h" #include "net/proxy/proxy_resolver.h" #include "url/gurl.h" @@ -105,6 +108,8 @@ class NET_EXPORT_PRIVATE ProxyScriptDecider { STATE_NONE, STATE_WAIT, STATE_WAIT_COMPLETE, + STATE_QUICK_CHECK, + STATE_QUICK_CHECK_COMPLETE, STATE_FETCH_PAC_SCRIPT, STATE_FETCH_PAC_SCRIPT_COMPLETE, STATE_VERIFY_PAC_SCRIPT, @@ -121,6 +126,9 @@ class NET_EXPORT_PRIVATE ProxyScriptDecider { int DoWait(); int DoWaitComplete(int result); + int DoQuickCheck(); + int DoQuickCheckComplete(int result); + int DoFetchPacScript(); int DoFetchPacScriptComplete(int result); @@ -161,6 +169,9 @@ class NET_EXPORT_PRIVATE ProxyScriptDecider { // (i.e. fallback to direct connections are prohibited). bool pac_mandatory_; + // Whether we have an existing custom PAC URL. + bool have_custom_pac_url_; + PacSourceList pac_sources_; State next_state_; @@ -175,6 +186,10 @@ class NET_EXPORT_PRIVATE ProxyScriptDecider { ProxyConfig effective_config_; scoped_refptr<ProxyResolverScriptData> script_data_; + AddressList wpad_addresses_; + base::OneShotTimer<ProxyScriptDecider> quick_check_timer_; + scoped_ptr<SingleRequestHostResolver> host_resolver_; + base::Time quick_check_start_time_; DISALLOW_COPY_AND_ASSIGN(ProxyScriptDecider); }; diff --git a/chromium/net/proxy/proxy_script_decider_unittest.cc b/chromium/net/proxy/proxy_script_decider_unittest.cc index 977bd4df642..18d8c6f5dcd 100644 --- a/chromium/net/proxy/proxy_script_decider_unittest.cc +++ b/chromium/net/proxy/proxy_script_decider_unittest.cc @@ -7,6 +7,7 @@ #include "base/bind.h" #include "base/memory/weak_ptr.h" #include "base/message_loop/message_loop.h" +#include "base/run_loop.h" #include "base/strings/string_util.h" #include "base/strings/utf_string_conversions.h" #include "base/time/time.h" @@ -14,11 +15,13 @@ #include "net/base/net_log.h" #include "net/base/net_log_unittest.h" #include "net/base/test_completion_callback.h" +#include "net/dns/mock_host_resolver.h" #include "net/proxy/dhcp_proxy_script_fetcher.h" #include "net/proxy/proxy_config.h" #include "net/proxy/proxy_resolver.h" #include "net/proxy/proxy_script_decider.h" #include "net/proxy/proxy_script_fetcher.h" +#include "net/url_request/url_request_context.h" #include "testing/gtest/include/gtest/gtest.h" namespace net { @@ -93,7 +96,12 @@ class Rules { class RuleBasedProxyScriptFetcher : public ProxyScriptFetcher { public: - explicit RuleBasedProxyScriptFetcher(const Rules* rules) : rules_(rules) {} + explicit RuleBasedProxyScriptFetcher(const Rules* rules) + : rules_(rules), request_context_(NULL) {} + + virtual void SetRequestContext(URLRequestContext* context) { + request_context_ = context; + } // ProxyScriptFetcher implementation. virtual int Fetch(const GURL& url, @@ -109,10 +117,13 @@ class RuleBasedProxyScriptFetcher : public ProxyScriptFetcher { virtual void Cancel() OVERRIDE {} - virtual URLRequestContext* GetRequestContext() const OVERRIDE { return NULL; } + virtual URLRequestContext* GetRequestContext() const OVERRIDE { + return request_context_; + } private: const Rules* rules_; + URLRequestContext* request_context_; }; // Succeed using custom PAC script. @@ -243,6 +254,93 @@ TEST(ProxyScriptDeciderTest, AutodetectSuccess) { EXPECT_EQ(rule.url, decider.effective_config().pac_url()); } +class ProxyScriptDeciderQuickCheckTest : public ::testing::Test { + public: + ProxyScriptDeciderQuickCheckTest() + : rule_(rules_.AddSuccessRule("http://wpad/wpad.dat")), + fetcher_(&rules_) { } + + virtual void SetUp() OVERRIDE { + request_context_.set_host_resolver(&resolver_); + fetcher_.SetRequestContext(&request_context_); + config_.set_auto_detect(true); + decider_.reset(new ProxyScriptDecider(&fetcher_, &dhcp_fetcher_, NULL)); + } + + int StartDecider() { + return decider_->Start(config_, base::TimeDelta(), true, + callback_.callback()); + } + + protected: + scoped_ptr<ProxyScriptDecider> decider_; + MockHostResolver resolver_; + Rules rules_; + Rules::Rule rule_; + TestCompletionCallback callback_; + + private: + URLRequestContext request_context_; + + RuleBasedProxyScriptFetcher fetcher_; + DoNothingDhcpProxyScriptFetcher dhcp_fetcher_; + + ProxyConfig config_; +}; + +#if 0 +// Fails if a synchronous DNS lookup success for wpad causes QuickCheck to fail. +TEST_F(ProxyScriptDeciderQuickCheckTest, SyncSuccess) { + resolver_.set_synchronous_mode(true); + resolver_.rules()->AddRule("wpad", "1.2.3.4"); + + EXPECT_EQ(OK, StartDecider()); + EXPECT_EQ(rule_.text(), decider_->script_data()->utf16()); + + EXPECT_TRUE(decider_->effective_config().has_pac_url()); + EXPECT_EQ(rule_.url, decider_->effective_config().pac_url()); +} + +// Fails if an asynchronous DNS lookup success for wpad causes QuickCheck to +// fail. +TEST_F(ProxyScriptDeciderQuickCheckTest, AsyncSuccess) { + resolver_.set_ondemand_mode(true); + resolver_.rules()->AddRule("wpad", "1.2.3.4"); + + EXPECT_EQ(ERR_IO_PENDING, StartDecider()); + ASSERT_TRUE(resolver_.has_pending_requests()); + resolver_.ResolveAllPending(); + callback_.WaitForResult(); + EXPECT_FALSE(resolver_.has_pending_requests()); + EXPECT_EQ(rule_.text(), decider_->script_data()->utf16()); + EXPECT_TRUE(decider_->effective_config().has_pac_url()); + EXPECT_EQ(rule_.url, decider_->effective_config().pac_url()); +} + +// Fails if an asynchronous DNS lookup failure (i.e. an NXDOMAIN) still causes +// ProxyScriptDecider to yield a PAC URL. +TEST_F(ProxyScriptDeciderQuickCheckTest, AsyncFail) { + resolver_.set_ondemand_mode(true); + resolver_.rules()->AddSimulatedFailure("wpad"); + EXPECT_EQ(ERR_IO_PENDING, StartDecider()); + ASSERT_TRUE(resolver_.has_pending_requests()); + resolver_.ResolveAllPending(); + callback_.WaitForResult(); + EXPECT_FALSE(decider_->effective_config().has_pac_url()); +} + +// Fails if a DNS lookup timeout either causes ProxyScriptDecider to yield a PAC +// URL or causes ProxyScriptDecider not to cancel its pending resolution. +TEST_F(ProxyScriptDeciderQuickCheckTest, AsyncTimeout) { + resolver_.set_ondemand_mode(true); + EXPECT_EQ(ERR_IO_PENDING, StartDecider()); + ASSERT_TRUE(resolver_.has_pending_requests()); + callback_.WaitForResult(); + EXPECT_FALSE(resolver_.has_pending_requests()); + EXPECT_FALSE(decider_->effective_config().has_pac_url()); +} +#endif + // Fails at WPAD (downloading), but succeeds in choosing the custom PAC. TEST(ProxyScriptDeciderTest, AutodetectFailCustomSuccess1) { Rules rules; diff --git a/chromium/net/proxy/proxy_script_fetcher_impl.cc b/chromium/net/proxy/proxy_script_fetcher_impl.cc index 2bf9e667792..f8925fa9c3c 100644 --- a/chromium/net/proxy/proxy_script_fetcher_impl.cc +++ b/chromium/net/proxy/proxy_script_fetcher_impl.cc @@ -211,7 +211,7 @@ void ProxyScriptFetcherImpl::OnResponseStarted(URLRequest* request) { } // Require HTTP responses to have a success status code. - if (request->url().SchemeIs("http") || request->url().SchemeIs("https")) { + if (request->url().SchemeIsHTTPOrHTTPS()) { // NOTE about status codes: We are like Firefox 3 in this respect. // {IE 7, Safari 3, Opera 9.5} do not care about the status code. if (request->GetResponseCode() != 200) { diff --git a/chromium/net/proxy/proxy_script_fetcher_impl_unittest.cc b/chromium/net/proxy/proxy_script_fetcher_impl_unittest.cc index 8d42514a957..9c1ca98a764 100644 --- a/chromium/net/proxy/proxy_script_fetcher_impl_unittest.cc +++ b/chromium/net/proxy/proxy_script_fetcher_impl_unittest.cc @@ -71,7 +71,8 @@ class RequestContext : public URLRequestContext { storage_.set_http_transaction_factory(new HttpCache( network_session.get(), HttpCache::DefaultBackend::InMemory(0))); URLRequestJobFactoryImpl* job_factory = new URLRequestJobFactoryImpl(); - job_factory->SetProtocolHandler("file", new FileProtocolHandler()); + job_factory->SetProtocolHandler( + "file", new FileProtocolHandler(base::MessageLoopProxy::current())); storage_.set_job_factory(job_factory); } diff --git a/chromium/net/quic/blocked_list.h b/chromium/net/quic/blocked_list.h deleted file mode 100644 index 3a7f989eeec..00000000000 --- a/chromium/net/quic/blocked_list.h +++ /dev/null @@ -1,91 +0,0 @@ -// Copyright (c) 2012 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. -// -// A combined list/hash set for read or write-blocked entities. - -#ifndef NET_QUIC_BLOCKED_LIST_H_ -#define NET_QUIC_BLOCKED_LIST_H_ - -#include <list> - -#include "base/containers/hash_tables.h" -#include "base/logging.h" - -namespace net { - -template <typename Object> -class BlockedList { - public: - // Called to add an object to the blocked list. This indicates - // the object should be notified when it can use the socket again. - // - // If this object is already on the list, it will not be added again. - void AddBlockedObject(Object object) { - // Only add the object to the list if we successfully add it to the set. - if (object_set_.insert(object).second) { - object_list_.push_back(object); - } - } - - // Called to remove an object from a blocked list. This should be - // called in the event the object is being deleted before the list is. - void RemoveBlockedObject(Object object) { - // Remove the object from the set. We'll check the set before calling - // OnCanWrite on a object from the list. - // - // There is potentially ordering unfairness should a session be removed and - // then readded (as it keeps its position in the list) but it's not worth - // the overhead to walk the list and remove it. - object_set_.erase(object); - } - - // Called when the socket is usable and some objects can access it. Returns - // the first object and removes it from the list. - Object GetNextBlockedObject() { - DCHECK(!IsEmpty()); - - // Walk the list to find the first object which was not removed from the - // set. - while (!object_list_.empty()) { - Object object = *object_list_.begin(); - object_list_.pop_front(); - int removed = object_set_.erase(object); - if (removed > 0) { - return object; - } - } - - // This is a bit of a hack: It's illegal to call GetNextBlockedObject() if - // the list is empty (see DCHECK above) but we must return something. This - // compiles for ints (returns 0) and pointers in the case that someone has a - // bug in their call site. - return 0; - }; - - // Returns the number of objects in the blocked list. - int NumObjects() { - return object_set_.size(); - }; - - // Returns true if there are no objects in the list, false otherwise. - bool IsEmpty() { - return object_set_.empty(); - }; - - private: - // A set tracking the objects. This is the authoritative container for - // determining if an object is blocked. Objects in the list will always - // be in the set. - base::hash_set<Object> object_set_; - // A list tracking the order in which objects were added to the list. - // Objects are added to the back and pulled off the front, but only get - // resumption calls if they're still in the set. - // It's possible to be in the list twice, but only the first entry will get an - // OnCanWrite call. - std::list<Object> object_list_; -}; - -} // namespace net - -#endif // NET_QUIC_BLOCKED_LIST_H_ diff --git a/chromium/net/quic/blocked_list_test.cc b/chromium/net/quic/blocked_list_test.cc deleted file mode 100644 index 074b6f52782..00000000000 --- a/chromium/net/quic/blocked_list_test.cc +++ /dev/null @@ -1,83 +0,0 @@ -// Copyright (c) 2012 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "net/quic/blocked_list.h" -#include "net/quic/quic_connection.h" -#include "testing/gtest/include/gtest/gtest.h" - -#if defined(COMPILER_GCC) -namespace BASE_HASH_NAMESPACE { -template<> -struct hash<const int*> { - std::size_t operator()(const int* ptr) const { - return hash<size_t>()(reinterpret_cast<size_t>(ptr)); - } -}; -} -#endif - -namespace net { -namespace test { -namespace { - -class BlockedListTest : public ::testing::Test { - protected: - BlockedListTest() : - item1_(0), - item2_(0), - item3_(0) { - } - - BlockedList<const int*> list_; - const int item1_; - const int item2_; - const int item3_; -}; - -TEST_F(BlockedListTest, BasicAdd) { - list_.AddBlockedObject(&item1_); - list_.AddBlockedObject(&item3_); - list_.AddBlockedObject(&item2_); - ASSERT_EQ(3, list_.NumObjects()); - ASSERT_FALSE(list_.IsEmpty()); - - EXPECT_EQ(&item1_, list_.GetNextBlockedObject()); - EXPECT_EQ(&item3_, list_.GetNextBlockedObject()); - EXPECT_EQ(&item2_, list_.GetNextBlockedObject()); -} - -TEST_F(BlockedListTest, AddAndRemove) { - list_.AddBlockedObject(&item1_); - list_.AddBlockedObject(&item3_); - list_.AddBlockedObject(&item2_); - ASSERT_EQ(3, list_.NumObjects()); - - list_.RemoveBlockedObject(&item3_); - ASSERT_EQ(2, list_.NumObjects()); - - EXPECT_EQ(&item1_, list_.GetNextBlockedObject()); - EXPECT_EQ(&item2_, list_.GetNextBlockedObject()); -} - -TEST_F(BlockedListTest, DuplicateAdd) { - list_.AddBlockedObject(&item1_); - list_.AddBlockedObject(&item3_); - list_.AddBlockedObject(&item2_); - - list_.AddBlockedObject(&item3_); - list_.AddBlockedObject(&item2_); - list_.AddBlockedObject(&item1_); - - ASSERT_EQ(3, list_.NumObjects()); - ASSERT_FALSE(list_.IsEmpty()); - - // Call in the original insert order. - EXPECT_EQ(&item1_, list_.GetNextBlockedObject()); - EXPECT_EQ(&item3_, list_.GetNextBlockedObject()); - EXPECT_EQ(&item2_, list_.GetNextBlockedObject()); -} - -} // namespace -} // namespace test -} // namespace net diff --git a/chromium/net/quic/congestion_control/fix_rate_sender.cc b/chromium/net/quic/congestion_control/fix_rate_sender.cc index dff52cf305d..99aa10fabf8 100644 --- a/chromium/net/quic/congestion_control/fix_rate_sender.cc +++ b/chromium/net/quic/congestion_control/fix_rate_sender.cc @@ -60,15 +60,18 @@ void FixRateSender::OnIncomingLoss(QuicTime /*ack_receive_time*/) { // Ignore losses for fix rate sender. } -void FixRateSender::SentPacket(QuicTime sent_time, - QuicPacketSequenceNumber /*sequence_number*/, - QuicByteCount bytes, - Retransmission is_retransmission) { +bool FixRateSender::SentPacket( + QuicTime sent_time, + QuicPacketSequenceNumber /*sequence_number*/, + QuicByteCount bytes, + Retransmission is_retransmission, + HasRetransmittableData /*has_retransmittable_data*/) { fix_rate_leaky_bucket_.Add(sent_time, bytes); paced_sender_.SentPacket(sent_time, bytes); if (is_retransmission == NOT_RETRANSMISSION) { data_in_flight_ += bytes; } + return true; } void FixRateSender::AbandoningPacket( @@ -80,7 +83,7 @@ QuicTime::Delta FixRateSender::TimeUntilSend( QuicTime now, Retransmission /*is_retransmission*/, HasRetransmittableData /*has_retransmittable_data*/, - IsHandshake /* handshake */) { + IsHandshake /*handshake*/) { if (CongestionWindow() > fix_rate_leaky_bucket_.BytesPending(now)) { if (CongestionWindow() <= data_in_flight_) { // We need an ack before we send more. diff --git a/chromium/net/quic/congestion_control/fix_rate_sender.h b/chromium/net/quic/congestion_control/fix_rate_sender.h index 38cebad165f..781deade1c0 100644 --- a/chromium/net/quic/congestion_control/fix_rate_sender.h +++ b/chromium/net/quic/congestion_control/fix_rate_sender.h @@ -32,10 +32,12 @@ class NET_EXPORT_PRIVATE FixRateSender : public SendAlgorithmInterface { QuicByteCount acked_bytes, QuicTime::Delta rtt) OVERRIDE; virtual void OnIncomingLoss(QuicTime ack_receive_time) OVERRIDE; - virtual void SentPacket(QuicTime sent_time, - QuicPacketSequenceNumber equence_number, - QuicByteCount bytes, - Retransmission is_retransmission) OVERRIDE; + virtual bool SentPacket( + QuicTime sent_time, + QuicPacketSequenceNumber equence_number, + QuicByteCount bytes, + Retransmission is_retransmission, + HasRetransmittableData has_retransmittable_data) OVERRIDE; virtual void AbandoningPacket(QuicPacketSequenceNumber sequence_number, QuicByteCount abandoned_bytes) OVERRIDE; virtual QuicTime::Delta TimeUntilSend( diff --git a/chromium/net/quic/congestion_control/fix_rate_test.cc b/chromium/net/quic/congestion_control/fix_rate_test.cc index f914ed671a6..e316f6517a7 100644 --- a/chromium/net/quic/congestion_control/fix_rate_test.cc +++ b/chromium/net/quic/congestion_control/fix_rate_test.cc @@ -63,11 +63,14 @@ TEST_F(FixRateTest, SenderAPI) { EXPECT_EQ(300000, sender_->BandwidthEstimate().ToBytesPerSecond()); EXPECT_TRUE(sender_->TimeUntilSend(clock_.Now(), NOT_RETRANSMISSION, HAS_RETRANSMITTABLE_DATA, NOT_HANDSHAKE).IsZero()); - sender_->SentPacket(clock_.Now(), 1, kMaxPacketSize, NOT_RETRANSMISSION); + sender_->SentPacket(clock_.Now(), 1, kMaxPacketSize, NOT_RETRANSMISSION, + HAS_RETRANSMITTABLE_DATA); EXPECT_TRUE(sender_->TimeUntilSend(clock_.Now(), NOT_RETRANSMISSION, HAS_RETRANSMITTABLE_DATA, NOT_HANDSHAKE).IsZero()); - sender_->SentPacket(clock_.Now(), 2, kMaxPacketSize, NOT_RETRANSMISSION); - sender_->SentPacket(clock_.Now(), 3, 600, NOT_RETRANSMISSION); + sender_->SentPacket(clock_.Now(), 2, kMaxPacketSize, NOT_RETRANSMISSION, + HAS_RETRANSMITTABLE_DATA); + sender_->SentPacket(clock_.Now(), 3, 600, NOT_RETRANSMISSION, + HAS_RETRANSMITTABLE_DATA); EXPECT_EQ(QuicTime::Delta::FromMilliseconds(10), sender_->TimeUntilSend(clock_.Now(), NOT_RETRANSMISSION, HAS_RETRANSMITTABLE_DATA, NOT_HANDSHAKE)); @@ -98,12 +101,12 @@ TEST_F(FixRateTest, FixRatePacing) { NOT_RETRANSMISSION, HAS_RETRANSMITTABLE_DATA, NOT_HANDSHAKE).IsZero()); sender_->SentPacket(clock_.Now(), sequence_number++, packet_size, - NOT_RETRANSMISSION); + NOT_RETRANSMISSION, HAS_RETRANSMITTABLE_DATA); EXPECT_TRUE(sender_->TimeUntilSend(clock_.Now(), NOT_RETRANSMISSION, HAS_RETRANSMITTABLE_DATA, NOT_HANDSHAKE).IsZero()); sender_->SentPacket(clock_.Now(), sequence_number++, packet_size, - NOT_RETRANSMISSION); + NOT_RETRANSMISSION, HAS_RETRANSMITTABLE_DATA); QuicTime::Delta advance_time = sender_->TimeUntilSend(clock_.Now(), NOT_RETRANSMISSION, HAS_RETRANSMITTABLE_DATA, NOT_HANDSHAKE); clock_.AdvanceTime(advance_time); diff --git a/chromium/net/quic/congestion_control/hybrid_slow_start.cc b/chromium/net/quic/congestion_control/hybrid_slow_start.cc index 8968dc94775..eee96ad5bb3 100644 --- a/chromium/net/quic/congestion_control/hybrid_slow_start.cc +++ b/chromium/net/quic/congestion_control/hybrid_slow_start.cc @@ -104,9 +104,4 @@ bool HybridSlowStart::Exit() { return false; } -QuicTime::Delta HybridSlowStart::SmoothedRtt() { - // TODO(satyamshekhar): Calculate and return smooth average of rtt over time. - return current_rtt_; -} - } // namespace net diff --git a/chromium/net/quic/congestion_control/hybrid_slow_start.h b/chromium/net/quic/congestion_control/hybrid_slow_start.h index b0e424883da..cee9c731257 100644 --- a/chromium/net/quic/congestion_control/hybrid_slow_start.h +++ b/chromium/net/quic/congestion_control/hybrid_slow_start.h @@ -46,8 +46,6 @@ class NET_EXPORT_PRIVATE HybridSlowStart { bool started() { return started_; } - QuicTime::Delta SmoothedRtt(); - private: const QuicClock* clock_; bool started_; diff --git a/chromium/net/quic/congestion_control/inter_arrival_overuse_detector.cc b/chromium/net/quic/congestion_control/inter_arrival_overuse_detector.cc index 73e005d788a..ea1c3afebfa 100644 --- a/chromium/net/quic/congestion_control/inter_arrival_overuse_detector.cc +++ b/chromium/net/quic/congestion_control/inter_arrival_overuse_detector.cc @@ -16,10 +16,6 @@ static const int kMinVarianceDelta = 10000; // Threshold for accumulated delta. static const int kThresholdAccumulatedDeltasUs = 1000; -// The higher the beta parameter, the lower is the effect of the input and the -// more damping of the noise. And the longer time for a detection. -static const float kBeta = 0.98f; - // Same as above, described as numerator and denominator. static const int kBetaNumerator = 49; static const int kBetaDenominator = 50; diff --git a/chromium/net/quic/congestion_control/inter_arrival_sender.cc b/chromium/net/quic/congestion_control/inter_arrival_sender.cc index 1aa7ab90e9b..5640a731ae9 100644 --- a/chromium/net/quic/congestion_control/inter_arrival_sender.cc +++ b/chromium/net/quic/congestion_control/inter_arrival_sender.cc @@ -235,14 +235,17 @@ void InterArrivalSender::OnIncomingLoss(QuicTime ack_receive_time) { } } -void InterArrivalSender::SentPacket(QuicTime sent_time, - QuicPacketSequenceNumber sequence_number, - QuicByteCount bytes, - Retransmission /*retransmit*/) { +bool InterArrivalSender::SentPacket( + QuicTime sent_time, + QuicPacketSequenceNumber sequence_number, + QuicByteCount bytes, + Retransmission /*is_retransmit*/, + HasRetransmittableData /*has_retransmittable_data*/) { if (probing_) { probe_->OnSentPacket(bytes); } paced_sender_->SentPacket(sent_time, bytes); + return true; } void InterArrivalSender::AbandoningPacket( @@ -258,7 +261,7 @@ QuicTime::Delta InterArrivalSender::TimeUntilSend( QuicTime now, Retransmission /*retransmit*/, HasRetransmittableData has_retransmittable_data, - IsHandshake /* handshake */) { + IsHandshake /*handshake*/) { // TODO(pwestin): implement outer_congestion_window_ logic. QuicTime::Delta outer_window = QuicTime::Delta::Zero(); diff --git a/chromium/net/quic/congestion_control/inter_arrival_sender.h b/chromium/net/quic/congestion_control/inter_arrival_sender.h index ad28ecd215a..2c455cc9d13 100644 --- a/chromium/net/quic/congestion_control/inter_arrival_sender.h +++ b/chromium/net/quic/congestion_control/inter_arrival_sender.h @@ -43,10 +43,12 @@ class NET_EXPORT_PRIVATE InterArrivalSender : public SendAlgorithmInterface { virtual void OnIncomingLoss(QuicTime ack_receive_time) OVERRIDE; - virtual void SentPacket(QuicTime sent_time, - QuicPacketSequenceNumber sequence_number, - QuicByteCount bytes, - Retransmission is_retransmit) OVERRIDE; + virtual bool SentPacket( + QuicTime sent_time, + QuicPacketSequenceNumber sequence_number, + QuicByteCount bytes, + Retransmission is_retransmit, + HasRetransmittableData has_retransmittable_data) OVERRIDE; virtual void AbandoningPacket(QuicPacketSequenceNumber sequence_number, QuicByteCount abandoned_bytes) OVERRIDE; diff --git a/chromium/net/quic/congestion_control/inter_arrival_sender_test.cc b/chromium/net/quic/congestion_control/inter_arrival_sender_test.cc index d0faca0f8f1..7392b1a3a67 100644 --- a/chromium/net/quic/congestion_control/inter_arrival_sender_test.cc +++ b/chromium/net/quic/congestion_control/inter_arrival_sender_test.cc @@ -41,7 +41,7 @@ class InterArrivalSenderTest : public ::testing::Test { bytes_in_packet, send_clock_.Now()); sender_.SentPacket(send_clock_.Now(), sequence_number_, bytes_in_packet, - NOT_RETRANSMISSION); + NOT_RETRANSMISSION, HAS_RETRANSMITTABLE_DATA); sequence_number_++; } EXPECT_FALSE(sender_.TimeUntilSend(send_clock_.Now(), diff --git a/chromium/net/quic/congestion_control/quic_congestion_control_test.cc b/chromium/net/quic/congestion_control/quic_congestion_control_test.cc index 0051acab1f7..457538f4552 100644 --- a/chromium/net/quic/congestion_control/quic_congestion_control_test.cc +++ b/chromium/net/quic/congestion_control/quic_congestion_control_test.cc @@ -49,7 +49,8 @@ TEST_F(QuicCongestionControlTest, FixedRateSenderAPI) { clock_.Now()); EXPECT_TRUE(manager_->TimeUntilSend(clock_.Now(), NOT_RETRANSMISSION, HAS_RETRANSMITTABLE_DATA, NOT_HANDSHAKE).IsZero()); - manager_->SentPacket(1, clock_.Now(), kMaxPacketSize, NOT_RETRANSMISSION); + manager_->SentPacket(1, clock_.Now(), kMaxPacketSize, NOT_RETRANSMISSION, + HAS_RETRANSMITTABLE_DATA); EXPECT_EQ(QuicTime::Delta::FromMilliseconds(40), manager_->TimeUntilSend(clock_.Now(), NOT_RETRANSMISSION, HAS_RETRANSMITTABLE_DATA, NOT_HANDSHAKE)); @@ -78,7 +79,8 @@ TEST_F(QuicCongestionControlTest, FixedRatePacing) { for (QuicPacketSequenceNumber i = 1; i <= 100; ++i) { EXPECT_TRUE(manager_->TimeUntilSend(clock_.Now(), NOT_RETRANSMISSION, HAS_RETRANSMITTABLE_DATA, NOT_HANDSHAKE).IsZero()); - manager_->SentPacket(i, clock_.Now(), kMaxPacketSize, NOT_RETRANSMISSION); + manager_->SentPacket(i, clock_.Now(), kMaxPacketSize, NOT_RETRANSMISSION, + HAS_RETRANSMITTABLE_DATA); QuicTime::Delta advance_time = manager_->TimeUntilSend(clock_.Now(), NOT_RETRANSMISSION, HAS_RETRANSMITTABLE_DATA, NOT_HANDSHAKE); clock_.AdvanceTime(advance_time); @@ -108,10 +110,12 @@ TEST_F(QuicCongestionControlTest, Pacing) { for (QuicPacketSequenceNumber i = 1; i <= 100;) { EXPECT_TRUE(manager_->TimeUntilSend(clock_.Now(), NOT_RETRANSMISSION, HAS_RETRANSMITTABLE_DATA, NOT_HANDSHAKE).IsZero()); - manager_->SentPacket(i++, clock_.Now(), kMaxPacketSize, NOT_RETRANSMISSION); + manager_->SentPacket(i++, clock_.Now(), kMaxPacketSize, NOT_RETRANSMISSION, + HAS_RETRANSMITTABLE_DATA); EXPECT_TRUE(manager_->TimeUntilSend(clock_.Now(), NOT_RETRANSMISSION, HAS_RETRANSMITTABLE_DATA, NOT_HANDSHAKE).IsZero()); - manager_->SentPacket(i++, clock_.Now(), kMaxPacketSize, NOT_RETRANSMISSION); + manager_->SentPacket(i++, clock_.Now(), kMaxPacketSize, NOT_RETRANSMISSION, + HAS_RETRANSMITTABLE_DATA); QuicTime::Delta advance_time = manager_->TimeUntilSend(clock_.Now(), NOT_RETRANSMISSION, HAS_RETRANSMITTABLE_DATA, NOT_HANDSHAKE); clock_.AdvanceTime(advance_time); diff --git a/chromium/net/quic/congestion_control/quic_congestion_manager.cc b/chromium/net/quic/congestion_control/quic_congestion_manager.cc index ba6bab83ba3..ec519db7c93 100644 --- a/chromium/net/quic/congestion_control/quic_congestion_manager.cc +++ b/chromium/net/quic/congestion_control/quic_congestion_manager.cc @@ -48,18 +48,21 @@ QuicCongestionManager::~QuicCongestionManager() { STLDeleteValues(&packet_history_map_); } -void QuicCongestionManager::SentPacket(QuicPacketSequenceNumber sequence_number, - QuicTime sent_time, - QuicByteCount bytes, - Retransmission retransmission) { +void QuicCongestionManager::SentPacket( + QuicPacketSequenceNumber sequence_number, + QuicTime sent_time, + QuicByteCount bytes, + Retransmission retransmission, + HasRetransmittableData has_retransmittable_data) { DCHECK(!ContainsKey(pending_packets_, sequence_number)); - send_algorithm_->SentPacket(sent_time, sequence_number, bytes, - retransmission); - packet_history_map_[sequence_number] = - new class SendAlgorithmInterface::SentPacket(bytes, sent_time); - pending_packets_[sequence_number] = bytes; - CleanupPacketHistory(); + if (send_algorithm_->SentPacket(sent_time, sequence_number, bytes, + retransmission, has_retransmittable_data)) { + packet_history_map_[sequence_number] = + new class SendAlgorithmInterface::SentPacket(bytes, sent_time); + pending_packets_[sequence_number] = bytes; + CleanupPacketHistory(); + } } // Called when a packet is timed out. @@ -156,6 +159,23 @@ const QuicTime::Delta QuicCongestionManager::DefaultRetransmissionTime() { return QuicTime::Delta::FromMilliseconds(kDefaultRetransmissionTimeMs); } +// Ensures that the Delayed Ack timer is always set to a value lesser +// than the retransmission timer's minimum value (MinRTO). We want the +// delayed ack to get back to the QUIC peer before the sender's +// retransmission timer triggers. Since we do not know the +// reverse-path one-way delay, we assume equal delays for forward and +// reverse paths, and ensure that the timer is set to less than half +// of the MinRTO. +// There may be a value in making this delay adaptive with the help of +// the sender and a signaling mechanism -- if the sender uses a +// different MinRTO, we may get spurious retransmissions. May not have +// any benefits, but if the delayed ack becomes a significant source +// of (likely, tail) latency, then consider such a mechanism. + +const QuicTime::Delta QuicCongestionManager::DelayedAckTime() { + return QuicTime::Delta::FromMilliseconds(kMinRetransmissionTimeMs/2); +} + const QuicTime::Delta QuicCongestionManager::GetRetransmissionDelay( size_t unacked_packets_count, size_t number_retransmissions) { diff --git a/chromium/net/quic/congestion_control/quic_congestion_manager.h b/chromium/net/quic/congestion_control/quic_congestion_manager.h index 8bfa3c1d425..66303439ba4 100644 --- a/chromium/net/quic/congestion_control/quic_congestion_manager.h +++ b/chromium/net/quic/congestion_control/quic_congestion_manager.h @@ -1,7 +1,3 @@ -// Copyright (c) 2013 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - // Copyright (c) 2012 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. @@ -50,7 +46,8 @@ class NET_EXPORT_PRIVATE QuicCongestionManager { virtual void SentPacket(QuicPacketSequenceNumber sequence_number, QuicTime sent_time, QuicByteCount bytes, - Retransmission retransmission); + Retransmission retransmission, + HasRetransmittableData has_retransmittable_data); // Called when a packet is timed out. virtual void AbandoningPacket(QuicPacketSequenceNumber sequence_number); @@ -85,6 +82,9 @@ class NET_EXPORT_PRIVATE QuicCongestionManager { const QuicTime::Delta DefaultRetransmissionTime(); + // Returns amount of time for delayed ack timer. + const QuicTime::Delta DelayedAckTime(); + const QuicTime::Delta GetRetransmissionDelay( size_t unacked_packets_count, size_t number_retransmissions); diff --git a/chromium/net/quic/congestion_control/quic_congestion_manager_test.cc b/chromium/net/quic/congestion_control/quic_congestion_manager_test.cc index 1cf44a2bdf8..80460f552b6 100644 --- a/chromium/net/quic/congestion_control/quic_congestion_manager_test.cc +++ b/chromium/net/quic/congestion_control/quic_congestion_manager_test.cc @@ -14,6 +14,7 @@ using testing::_; using testing::StrictMock; +using testing::Return; namespace net { namespace test { @@ -64,7 +65,8 @@ TEST_F(QuicCongestionManagerTest, Bandwidth) { clock_.AdvanceTime(advance_time); EXPECT_TRUE(manager_->TimeUntilSend( clock_.Now(), NOT_RETRANSMISSION, kIgnored, NOT_HANDSHAKE).IsZero()); - manager_->SentPacket(i, clock_.Now(), 1000, NOT_RETRANSMISSION); + manager_->SentPacket(i, clock_.Now(), 1000, NOT_RETRANSMISSION, + HAS_RETRANSMITTABLE_DATA); // Ack the packet we sent. ack.received_info.largest_observed = i; manager_->OnIncomingAckFrame(ack, clock_.Now()); @@ -92,8 +94,8 @@ TEST_F(QuicCongestionManagerTest, BandwidthWith1SecondGap) { clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(10)); EXPECT_TRUE(manager_->TimeUntilSend( clock_.Now(), NOT_RETRANSMISSION, kIgnored, NOT_HANDSHAKE).IsZero()); - manager_->SentPacket( - sequence_number, clock_.Now(), 1000, NOT_RETRANSMISSION); + manager_->SentPacket(sequence_number, clock_.Now(), 1000, + NOT_RETRANSMISSION, HAS_RETRANSMITTABLE_DATA); // Ack the packet we sent. ack.received_info.largest_observed = sequence_number; manager_->OnIncomingAckFrame(ack, clock_.Now()); @@ -118,7 +120,8 @@ TEST_F(QuicCongestionManagerTest, BandwidthWith1SecondGap) { for (int i = 1; i <= 150; ++i) { EXPECT_TRUE(manager_->TimeUntilSend( clock_.Now(), NOT_RETRANSMISSION, kIgnored, NOT_HANDSHAKE).IsZero()); - manager_->SentPacket(i + 100, clock_.Now(), 1000, NOT_RETRANSMISSION); + manager_->SentPacket(i + 100, clock_.Now(), 1000, NOT_RETRANSMISSION, + HAS_RETRANSMITTABLE_DATA); clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(10)); // Ack the packet we sent. ack.received_info.largest_observed = i + 100; @@ -141,11 +144,13 @@ TEST_F(QuicCongestionManagerTest, Rtt) { QuicPacketSequenceNumber sequence_number = 1; QuicTime::Delta expected_rtt = QuicTime::Delta::FromMilliseconds(15); - EXPECT_CALL(*send_algorithm, SentPacket(_, _, _, _)).Times(1); + EXPECT_CALL(*send_algorithm, SentPacket(_, _, _, _, _)) + .Times(1).WillOnce(Return(true)); EXPECT_CALL(*send_algorithm, OnIncomingAck(sequence_number, _, expected_rtt)).Times(1); - manager_->SentPacket(sequence_number, clock_.Now(), 1000, NOT_RETRANSMISSION); + manager_->SentPacket(sequence_number, clock_.Now(), 1000, NOT_RETRANSMISSION, + HAS_RETRANSMITTABLE_DATA); clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(20)); QuicAckFrame ack; @@ -167,11 +172,13 @@ TEST_F(QuicCongestionManagerTest, RttWithInvalidDelta) { QuicPacketSequenceNumber sequence_number = 1; QuicTime::Delta expected_rtt = QuicTime::Delta::Infinite(); - EXPECT_CALL(*send_algorithm, SentPacket(_, _, _, _)).Times(1); + EXPECT_CALL(*send_algorithm, SentPacket(_, _, _, _, _)) + .Times(1).WillOnce(Return(true)); EXPECT_CALL(*send_algorithm, OnIncomingAck(sequence_number, _, expected_rtt)).Times(1); - manager_->SentPacket(sequence_number, clock_.Now(), 1000, NOT_RETRANSMISSION); + manager_->SentPacket(sequence_number, clock_.Now(), 1000, NOT_RETRANSMISSION, + HAS_RETRANSMITTABLE_DATA); clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(10)); QuicAckFrame ack; @@ -193,11 +200,13 @@ TEST_F(QuicCongestionManagerTest, RttInfiniteDelta) { QuicPacketSequenceNumber sequence_number = 1; QuicTime::Delta expected_rtt = QuicTime::Delta::Infinite(); - EXPECT_CALL(*send_algorithm, SentPacket(_, _, _, _)).Times(1); + EXPECT_CALL(*send_algorithm, SentPacket(_, _, _, _, _)) + .Times(1).WillOnce(Return(true)); EXPECT_CALL(*send_algorithm, OnIncomingAck(sequence_number, _, expected_rtt)).Times(1); - manager_->SentPacket(sequence_number, clock_.Now(), 1000, NOT_RETRANSMISSION); + manager_->SentPacket(sequence_number, clock_.Now(), 1000, NOT_RETRANSMISSION, + HAS_RETRANSMITTABLE_DATA); clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(10)); QuicAckFrame ack; @@ -218,11 +227,13 @@ TEST_F(QuicCongestionManagerTest, RttZeroDelta) { QuicPacketSequenceNumber sequence_number = 1; QuicTime::Delta expected_rtt = QuicTime::Delta::FromMilliseconds(10); - EXPECT_CALL(*send_algorithm, SentPacket(_, _, _, _)).Times(1); - EXPECT_CALL(*send_algorithm, - OnIncomingAck(sequence_number, _, expected_rtt)).Times(1); + EXPECT_CALL(*send_algorithm, SentPacket(_, _, _, _, _)) + .Times(1).WillOnce(Return(true)); + EXPECT_CALL(*send_algorithm, OnIncomingAck(sequence_number, _, expected_rtt)) + .Times(1); - manager_->SentPacket(sequence_number, clock_.Now(), 1000, NOT_RETRANSMISSION); + manager_->SentPacket(sequence_number, clock_.Now(), 1000, NOT_RETRANSMISSION, + HAS_RETRANSMITTABLE_DATA); clock_.AdvanceTime(expected_rtt); QuicAckFrame ack; diff --git a/chromium/net/quic/congestion_control/send_algorithm_interface.cc b/chromium/net/quic/congestion_control/send_algorithm_interface.cc index ce24a00b1ac..90399e7d05b 100644 --- a/chromium/net/quic/congestion_control/send_algorithm_interface.cc +++ b/chromium/net/quic/congestion_control/send_algorithm_interface.cc @@ -14,7 +14,7 @@ const bool kUseReno = false; // TODO(ianswett): Increase the max congestion window once the RTO logic is // improved, particularly in cases when RTT is larger than the RTO. b/10075719 // Maximum number of outstanding packets for tcp. -const QuicTcpCongestionWindow kMaxTcpCongestionWindow = 50; +const QuicTcpCongestionWindow kMaxTcpCongestionWindow = 100; // Factory for send side congestion control algorithm. SendAlgorithmInterface* SendAlgorithmInterface::Create( diff --git a/chromium/net/quic/congestion_control/send_algorithm_interface.h b/chromium/net/quic/congestion_control/send_algorithm_interface.h index 8896b2b06df..c29f22545d0 100644 --- a/chromium/net/quic/congestion_control/send_algorithm_interface.h +++ b/chromium/net/quic/congestion_control/send_algorithm_interface.h @@ -55,11 +55,15 @@ class NET_EXPORT_PRIVATE SendAlgorithmInterface { virtual void OnIncomingLoss(QuicTime ack_receive_time) = 0; // Inform that we sent x bytes to the wire, and if that was a retransmission. + // Returns true if the packet should be tracked by the congestion manager, + // false otherwise. This is used by implementations such as tcp_cubic_sender + // that do not count outgoing ACK packets against the congestion window. // Note: this function must be called for every packet sent to the wire. - virtual void SentPacket(QuicTime sent_time, + virtual bool SentPacket(QuicTime sent_time, QuicPacketSequenceNumber sequence_number, QuicByteCount bytes, - Retransmission is_retransmission) = 0; + Retransmission is_retransmission, + HasRetransmittableData is_retransmittable) = 0; // Called when a packet is timed out. virtual void AbandoningPacket(QuicPacketSequenceNumber sequence_number, diff --git a/chromium/net/quic/congestion_control/tcp_cubic_sender.cc b/chromium/net/quic/congestion_control/tcp_cubic_sender.cc index 1e98c12cbf3..52c910e8592 100644 --- a/chromium/net/quic/congestion_control/tcp_cubic_sender.cc +++ b/chromium/net/quic/congestion_control/tcp_cubic_sender.cc @@ -65,6 +65,7 @@ void TcpCubicSender::OnIncomingQuicCongestionFeedbackFrame( void TcpCubicSender::OnIncomingAck( QuicPacketSequenceNumber acked_sequence_number, QuicByteCount acked_bytes, QuicTime::Delta rtt) { + DCHECK_GE(bytes_in_flight_, acked_bytes); bytes_in_flight_ -= acked_bytes; CongestionAvoidance(acked_sequence_number); AckAccounting(rtt); @@ -93,10 +94,16 @@ void TcpCubicSender::OnIncomingLoss(QuicTime /*ack_receive_time*/) { DLOG(INFO) << "Incoming loss; congestion window:" << congestion_window_; } -void TcpCubicSender::SentPacket(QuicTime /*sent_time*/, +bool TcpCubicSender::SentPacket(QuicTime /*sent_time*/, QuicPacketSequenceNumber sequence_number, QuicByteCount bytes, - Retransmission is_retransmission) { + Retransmission is_retransmission, + HasRetransmittableData is_retransmittable) { + // Only update bytes_in_flight_ for data packets. + if (is_retransmittable != HAS_RETRANSMITTABLE_DATA) { + return false; + } + bytes_in_flight_ += bytes; if (is_retransmission == NOT_RETRANSMISSION && update_end_sequence_number_) { end_sequence_number_ = sequence_number; @@ -105,10 +112,12 @@ void TcpCubicSender::SentPacket(QuicTime /*sent_time*/, DLOG(INFO) << "Stop update end sequence number @" << sequence_number; } } + return true; } void TcpCubicSender::AbandoningPacket(QuicPacketSequenceNumber sequence_number, QuicByteCount abandoned_bytes) { + DCHECK_GE(bytes_in_flight_, abandoned_bytes); bytes_in_flight_ -= abandoned_bytes; } @@ -195,7 +204,7 @@ void TcpCubicSender::CongestionAvoidance(QuicPacketSequenceNumber ack) { } // congestion_window_cnt is the number of acks since last change of snd_cwnd if (congestion_window_ < max_tcp_congestion_window_) { - // TCP slow start, exponentail growth, increase by one for each ACK. + // TCP slow start, exponential growth, increase by one for each ACK. congestion_window_++; } DLOG(INFO) << "Slow start; congestion window:" << congestion_window_; @@ -211,8 +220,9 @@ void TcpCubicSender::CongestionAvoidance(QuicPacketSequenceNumber ack) { } DLOG(INFO) << "Reno; congestion window:" << congestion_window_; } else { - congestion_window_ = cubic_.CongestionWindowAfterAck(congestion_window_, - delay_min_); + congestion_window_ = std::min( + max_tcp_congestion_window_, + cubic_.CongestionWindowAfterAck(congestion_window_, delay_min_)); DLOG(INFO) << "Cubic; congestion window:" << congestion_window_; } } diff --git a/chromium/net/quic/congestion_control/tcp_cubic_sender.h b/chromium/net/quic/congestion_control/tcp_cubic_sender.h index c22813a2cfd..db829c29bba 100644 --- a/chromium/net/quic/congestion_control/tcp_cubic_sender.h +++ b/chromium/net/quic/congestion_control/tcp_cubic_sender.h @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. // -// TCP cubic send side congestion algorithm, emulates the behaviour of +// TCP cubic send side congestion algorithm, emulates the behavior of // TCP cubic. #ifndef NET_QUIC_CONGESTION_CONTROL_TCP_CUBIC_SENDER_H_ @@ -41,10 +41,12 @@ class NET_EXPORT_PRIVATE TcpCubicSender : public SendAlgorithmInterface { QuicByteCount acked_bytes, QuicTime::Delta rtt) OVERRIDE; virtual void OnIncomingLoss(QuicTime ack_receive_time) OVERRIDE; - virtual void SentPacket(QuicTime sent_time, - QuicPacketSequenceNumber sequence_number, - QuicByteCount bytes, - Retransmission is_retransmission) OVERRIDE; + virtual bool SentPacket( + QuicTime sent_time, + QuicPacketSequenceNumber sequence_number, + QuicByteCount bytes, + Retransmission is_retransmission, + HasRetransmittableData is_retransmittable) OVERRIDE; virtual void AbandoningPacket(QuicPacketSequenceNumber sequence_number, QuicByteCount abandoned_bytes) OVERRIDE; virtual QuicTime::Delta TimeUntilSend( diff --git a/chromium/net/quic/congestion_control/tcp_cubic_sender_test.cc b/chromium/net/quic/congestion_control/tcp_cubic_sender_test.cc index e9f88937426..c7046fcb2f7 100644 --- a/chromium/net/quic/congestion_control/tcp_cubic_sender_test.cc +++ b/chromium/net/quic/congestion_control/tcp_cubic_sender_test.cc @@ -13,13 +13,15 @@ namespace net { namespace test { const uint32 kDefaultWindowTCP = 10 * kMaxPacketSize; -const QuicByteCount kNoNBytesInFlight = 0; +// TODO(ianswett): Remove 10000 once b/10075719 is fixed. +const QuicTcpCongestionWindow kDefaultMaxCongestionWindowTCP = 10000; class TcpCubicSenderPeer : public TcpCubicSender { public: - // TODO(ianswett): Remove 10000 once b/10075719 is fixed. - TcpCubicSenderPeer(const QuicClock* clock, bool reno) - : TcpCubicSender(clock, reno, 10000) { + TcpCubicSenderPeer(const QuicClock* clock, + bool reno, + QuicTcpCongestionWindow max_tcp_congestion_window) + : TcpCubicSender(clock, reno, max_tcp_congestion_window) { } using TcpCubicSender::AvailableCongestionWindow; using TcpCubicSender::CongestionWindow; @@ -29,12 +31,13 @@ class TcpCubicSenderPeer : public TcpCubicSender { class TcpCubicSenderTest : public ::testing::Test { protected: TcpCubicSenderTest() - : rtt_(QuicTime::Delta::FromMilliseconds(60)), - one_ms_(QuicTime::Delta::FromMilliseconds(1)), - sender_(new TcpCubicSenderPeer(&clock_, true)), - receiver_(new TcpReceiver()), - sequence_number_(1), - acked_sequence_number_(0) { + : rtt_(QuicTime::Delta::FromMilliseconds(60)), + one_ms_(QuicTime::Delta::FromMilliseconds(1)), + sender_(new TcpCubicSenderPeer(&clock_, true, + kDefaultMaxCongestionWindowTCP)), + receiver_(new TcpReceiver()), + sequence_number_(1), + acked_sequence_number_(0) { } void SendAvailableCongestionWindow() { @@ -42,7 +45,7 @@ class TcpCubicSenderTest : public ::testing::Test { while (bytes_to_send > 0) { QuicByteCount bytes_in_packet = std::min(kMaxPacketSize, bytes_to_send); sender_->SentPacket(clock_.Now(), sequence_number_++, bytes_in_packet, - NOT_RETRANSMISSION); + NOT_RETRANSMISSION, HAS_RETRANSMITTABLE_DATA); bytes_to_send -= bytes_in_packet; if (bytes_to_send > 0) { EXPECT_TRUE(sender_->TimeUntilSend(clock_.Now(), NOT_RETRANSMISSION, @@ -87,7 +90,7 @@ TEST_F(TcpCubicSenderTest, SimpleSender) { // And that window is un-affected. EXPECT_EQ(kDefaultWindowTCP, sender_->AvailableCongestionWindow()); - // A retransmitt should always retun 0. + // A retransmit should always return 0. EXPECT_TRUE(sender_->TimeUntilSend(clock_.Now(), IS_RETRANSMISSION, HAS_RETRANSMITTABLE_DATA, NOT_HANDSHAKE).IsZero()); } @@ -200,7 +203,7 @@ TEST_F(TcpCubicSenderTest, SlowStartPacketLoss) { EXPECT_EQ(expected_congestion_window, sender_->CongestionWindow()); // Testing Reno phase. - // We need to ack half of the pending packet before we can send agin. + // We need to ack half of the pending packet before we can send again. int number_of_packets_in_window = expected_congestion_window / kMaxPacketSize; AckNPackets(number_of_packets_in_window); EXPECT_EQ(expected_congestion_window, sender_->CongestionWindow()); @@ -232,8 +235,8 @@ TEST_F(TcpCubicSenderTest, RetransmissionDelay) { sender_->AckAccounting(QuicTime::Delta::FromMilliseconds(kRttMs)); // Initial value is to set the median deviation to half of the initial - // rtt, the median in then multiplied by a factor of 4 and finaly the - // smoothed rtt is added which is the inital rtt. + // rtt, the median in then multiplied by a factor of 4 and finally the + // smoothed rtt is added which is the initial rtt. QuicTime::Delta expected_delay = QuicTime::Delta::FromMilliseconds(kRttMs + kRttMs / 2 * 4); EXPECT_EQ(expected_delay, sender_->RetransmissionDelay()); @@ -252,5 +255,119 @@ TEST_F(TcpCubicSenderTest, RetransmissionDelay) { sender_->RetransmissionDelay().ToMilliseconds(), 1); } + +TEST_F(TcpCubicSenderTest, SlowStartMaxCongestionWindow) { + const QuicTcpCongestionWindow kMaxCongestionWindowTCP = 50; + const int kNumberOfAck = 100; + sender_.reset( + new TcpCubicSenderPeer(&clock_, false, kMaxCongestionWindowTCP)); + + QuicCongestionFeedbackFrame feedback; + // At startup make sure we can send. + EXPECT_TRUE(sender_->TimeUntilSend(clock_.Now(), + NOT_RETRANSMISSION, HAS_RETRANSMITTABLE_DATA, NOT_HANDSHAKE).IsZero()); + // Get default QuicCongestionFeedbackFrame from receiver. + ASSERT_TRUE(receiver_->GenerateCongestionFeedback(&feedback)); + sender_->OnIncomingQuicCongestionFeedbackFrame(feedback, clock_.Now(), + not_used_); + // Make sure we can send. + EXPECT_TRUE(sender_->TimeUntilSend(clock_.Now(), + NOT_RETRANSMISSION, HAS_RETRANSMITTABLE_DATA, NOT_HANDSHAKE).IsZero()); + + for (int i = 0; i < kNumberOfAck; ++i) { + // Send our full congestion window. + SendAvailableCongestionWindow(); + AckNPackets(2); + } + QuicByteCount expected_congestion_window = + kMaxCongestionWindowTCP * kMaxPacketSize; + EXPECT_EQ(expected_congestion_window, sender_->CongestionWindow()); +} + +TEST_F(TcpCubicSenderTest, TcpRenoMaxCongestionWindow) { + const QuicTcpCongestionWindow kMaxCongestionWindowTCP = 50; + const int kNumberOfAck = 1000; + sender_.reset( + new TcpCubicSenderPeer(&clock_, true, kMaxCongestionWindowTCP)); + + QuicCongestionFeedbackFrame feedback; + // At startup make sure we can send. + EXPECT_TRUE(sender_->TimeUntilSend(clock_.Now(), + NOT_RETRANSMISSION, HAS_RETRANSMITTABLE_DATA, NOT_HANDSHAKE).IsZero()); + // Get default QuicCongestionFeedbackFrame from receiver. + ASSERT_TRUE(receiver_->GenerateCongestionFeedback(&feedback)); + sender_->OnIncomingQuicCongestionFeedbackFrame(feedback, clock_.Now(), + not_used_); + // Make sure we can send. + EXPECT_TRUE(sender_->TimeUntilSend(clock_.Now(), + NOT_RETRANSMISSION, HAS_RETRANSMITTABLE_DATA, NOT_HANDSHAKE).IsZero()); + + SendAvailableCongestionWindow(); + AckNPackets(2); + // Make sure we fall out of slow start. + sender_->OnIncomingLoss(clock_.Now()); + + for (int i = 0; i < kNumberOfAck; ++i) { + // Send our full congestion window. + SendAvailableCongestionWindow(); + AckNPackets(2); + } + + QuicByteCount expected_congestion_window = + kMaxCongestionWindowTCP * kMaxPacketSize; + EXPECT_EQ(expected_congestion_window, sender_->CongestionWindow()); +} + +TEST_F(TcpCubicSenderTest, TcpCubicMaxCongestionWindow) { + const QuicTcpCongestionWindow kMaxCongestionWindowTCP = 50; + const int kNumberOfAck = 1000; + sender_.reset( + new TcpCubicSenderPeer(&clock_, false, kMaxCongestionWindowTCP)); + + QuicCongestionFeedbackFrame feedback; + // At startup make sure we can send. + EXPECT_TRUE(sender_->TimeUntilSend(clock_.Now(), + NOT_RETRANSMISSION, HAS_RETRANSMITTABLE_DATA, NOT_HANDSHAKE).IsZero()); + // Get default QuicCongestionFeedbackFrame from receiver. + ASSERT_TRUE(receiver_->GenerateCongestionFeedback(&feedback)); + sender_->OnIncomingQuicCongestionFeedbackFrame(feedback, clock_.Now(), + not_used_); + // Make sure we can send. + EXPECT_TRUE(sender_->TimeUntilSend(clock_.Now(), + NOT_RETRANSMISSION, HAS_RETRANSMITTABLE_DATA, NOT_HANDSHAKE).IsZero()); + + SendAvailableCongestionWindow(); + AckNPackets(2); + // Make sure we fall out of slow start. + sender_->OnIncomingLoss(clock_.Now()); + + for (int i = 0; i < kNumberOfAck; ++i) { + // Send our full congestion window. + SendAvailableCongestionWindow(); + AckNPackets(2); + } + + QuicByteCount expected_congestion_window = + kMaxCongestionWindowTCP * kMaxPacketSize; + EXPECT_EQ(expected_congestion_window, sender_->CongestionWindow()); +} + +TEST_F(TcpCubicSenderTest, CongestionWindowNotAffectedByAcks) { + QuicByteCount congestion_window = sender_->AvailableCongestionWindow(); + + // Send a packet with no retransmittable data, and ensure that the congestion + // window doesn't change. + QuicByteCount bytes_in_packet = std::min(kMaxPacketSize, congestion_window); + sender_->SentPacket(clock_.Now(), sequence_number_++, bytes_in_packet, + NOT_RETRANSMISSION, NO_RETRANSMITTABLE_DATA); + EXPECT_EQ(congestion_window, sender_->AvailableCongestionWindow()); + + // Send a data packet with retransmittable data, and ensure that the + // congestion window has shrunk. + sender_->SentPacket(clock_.Now(), sequence_number_++, bytes_in_packet, + NOT_RETRANSMISSION, HAS_RETRANSMITTABLE_DATA); + EXPECT_GT(congestion_window, sender_->AvailableCongestionWindow()); +} + } // namespace test } // namespace net diff --git a/chromium/net/quic/crypto/aes_128_gcm_12_encrypter.h b/chromium/net/quic/crypto/aes_128_gcm_12_encrypter.h index 451f84df6f8..ca9a2b1fca6 100644 --- a/chromium/net/quic/crypto/aes_128_gcm_12_encrypter.h +++ b/chromium/net/quic/crypto/aes_128_gcm_12_encrypter.h @@ -61,8 +61,6 @@ class NET_EXPORT_PRIVATE Aes128Gcm12Encrypter : public QuicEncrypter { unsigned char key_[16]; // The nonce prefix. unsigned char nonce_prefix_[4]; - // last_seq_num_ is the last sequence number observed. - QuicPacketSequenceNumber last_seq_num_; #if defined(USE_OPENSSL) ScopedEVPCipherCtx ctx_; diff --git a/chromium/net/quic/crypto/aes_128_gcm_12_encrypter_nss.cc b/chromium/net/quic/crypto/aes_128_gcm_12_encrypter_nss.cc index 1cd3540c884..ae6adab462b 100644 --- a/chromium/net/quic/crypto/aes_128_gcm_12_encrypter_nss.cc +++ b/chromium/net/quic/crypto/aes_128_gcm_12_encrypter_nss.cc @@ -250,7 +250,7 @@ SECStatus My_Encrypt(PK11SymKey* key, } // namespace -Aes128Gcm12Encrypter::Aes128Gcm12Encrypter() : last_seq_num_(0) { +Aes128Gcm12Encrypter::Aes128Gcm12Encrypter() { ignore_result(g_gcm_support_checker.Get()); } @@ -350,12 +350,8 @@ QuicData* Aes128Gcm12Encrypter::EncryptPacket( size_t ciphertext_size = GetCiphertextSize(plaintext.length()); scoped_ptr<char[]> ciphertext(new char[ciphertext_size]); - if (last_seq_num_ != 0 && sequence_number <= last_seq_num_) { - DLOG(FATAL) << "Sequence numbers regressed"; - return NULL; - } - last_seq_num_ = sequence_number; - + // TODO(ianswett): Introduce a check to ensure that we don't encrypt with the + // same sequence number twice. uint8 nonce[kNoncePrefixSize + sizeof(sequence_number)]; COMPILE_ASSERT(sizeof(nonce) == kAESNonceSize, bad_sequence_number_size); memcpy(nonce, nonce_prefix_, kNoncePrefixSize); diff --git a/chromium/net/quic/crypto/aes_128_gcm_12_encrypter_openssl.cc b/chromium/net/quic/crypto/aes_128_gcm_12_encrypter_openssl.cc index 79d0ec1a8a0..166fd55cb7c 100644 --- a/chromium/net/quic/crypto/aes_128_gcm_12_encrypter_openssl.cc +++ b/chromium/net/quic/crypto/aes_128_gcm_12_encrypter_openssl.cc @@ -21,7 +21,7 @@ const size_t kAESNonceSize = 12; } // namespace -Aes128Gcm12Encrypter::Aes128Gcm12Encrypter() : last_seq_num_(0) {} +Aes128Gcm12Encrypter::Aes128Gcm12Encrypter() {} Aes128Gcm12Encrypter::~Aes128Gcm12Encrypter() {} @@ -118,12 +118,8 @@ QuicData* Aes128Gcm12Encrypter::EncryptPacket( size_t ciphertext_size = GetCiphertextSize(plaintext.length()); scoped_ptr<char[]> ciphertext(new char[ciphertext_size]); - if (last_seq_num_ != 0 && sequence_number <= last_seq_num_) { - DLOG(FATAL) << "Sequence numbers regressed"; - return NULL; - } - last_seq_num_ = sequence_number; - + // TODO(ianswett): Introduce a check to ensure that we don't encrypt with the + // same sequence number twice. uint8 nonce[kNoncePrefixSize + sizeof(sequence_number)]; COMPILE_ASSERT(sizeof(nonce) == kAESNonceSize, bad_sequence_number_size); memcpy(nonce, nonce_prefix_, kNoncePrefixSize); diff --git a/chromium/net/quic/crypto/crypto_framer.h b/chromium/net/quic/crypto/crypto_framer.h index b070c66e277..ea69f3a5724 100644 --- a/chromium/net/quic/crypto/crypto_framer.h +++ b/chromium/net/quic/crypto/crypto_framer.h @@ -84,8 +84,6 @@ class NET_EXPORT_PRIVATE CryptoFramer { size_t pad_length, uint32* end_offset); - void set_error(QuicErrorCode error) { error_ = error; } - // Represents the current state of the parsing state machine. enum CryptoFramerState { STATE_READING_TAG, diff --git a/chromium/net/quic/crypto/crypto_handshake.cc b/chromium/net/quic/crypto/crypto_handshake.cc index d6a76f9f511..51465b75c4b 100644 --- a/chromium/net/quic/crypto/crypto_handshake.cc +++ b/chromium/net/quic/crypto/crypto_handshake.cc @@ -84,11 +84,6 @@ void CryptoHandshakeMessage::MarkDirty() { serialized_.reset(); } -void CryptoHandshakeMessage::Insert(QuicTagValueMap::const_iterator begin, - QuicTagValueMap::const_iterator end) { - tag_value_map_.insert(begin, end); -} - void CryptoHandshakeMessage::SetTaglist(QuicTag tag, ...) { // Warning, if sizeof(QuicTag) > sizeof(int) then this function will break // because the terminating 0 will only be promoted to int. @@ -326,8 +321,7 @@ string CryptoHandshakeMessage::DebugStringInternal(size_t indent) const { } QuicCryptoNegotiatedParameters::QuicCryptoNegotiatedParameters() - : version(0), - key_exchange(0), + : key_exchange(0), aead(0) { } @@ -470,6 +464,12 @@ void QuicCryptoClientConfig::CachedState::SetProof(const vector<string>& certs, server_config_sig_ = signature.as_string(); } +void QuicCryptoClientConfig::CachedState::ClearProof() { + SetProofInvalid(); + certs_.clear(); + server_config_sig_.clear(); +} + void QuicCryptoClientConfig::CachedState::SetProofValid() { server_config_valid_ = true; } @@ -758,6 +758,10 @@ QuicErrorCode QuicCryptoClientConfig::FillClientHello( 0 /* sequence number */, StringPiece() /* associated data */, cetv_plaintext.AsStringPiece())); + if (!cetv_ciphertext.get()) { + *error_details = "Packet encryption failed"; + return QUIC_ENCRYPTION_FAILURE; + } out->SetStringPiece(kCETV, cetv_ciphertext->AsStringPiece()); out->MarkDirty(); @@ -788,9 +792,9 @@ QuicErrorCode QuicCryptoClientConfig::FillClientHello( } QuicErrorCode QuicCryptoClientConfig::ProcessRejection( - CachedState* cached, const CryptoHandshakeMessage& rej, QuicWallTime now, + CachedState* cached, QuicCryptoNegotiatedParameters* out_params, string* error_details) { DCHECK(error_details != NULL); @@ -822,8 +826,9 @@ QuicErrorCode QuicCryptoClientConfig::ProcessRejection( } StringPiece proof, cert_bytes; - if (rej.GetStringPiece(kPROF, &proof) && - rej.GetStringPiece(kCertificateTag, &cert_bytes)) { + bool has_proof = rej.GetStringPiece(kPROF, &proof); + bool has_cert = rej.GetStringPiece(kCertificateTag, &cert_bytes); + if (has_proof && has_cert) { vector<string> certs; if (!CertCompressor::DecompressChain(cert_bytes, out_params->cached_certs, common_cert_sets, &certs)) { @@ -832,6 +837,17 @@ QuicErrorCode QuicCryptoClientConfig::ProcessRejection( } cached->SetProof(certs, proof); + } else { + cached->ClearProof(); + if (has_proof && !has_cert) { + *error_details = "Certificate missing"; + return QUIC_INVALID_CRYPTO_MESSAGE_PARAMETER; + } + + if (!has_proof && has_cert) { + *error_details = "Proof missing"; + return QUIC_INVALID_CRYPTO_MESSAGE_PARAMETER; + } } return QUIC_NO_ERROR; @@ -840,6 +856,7 @@ QuicErrorCode QuicCryptoClientConfig::ProcessRejection( QuicErrorCode QuicCryptoClientConfig::ProcessServerHello( const CryptoHandshakeMessage& server_hello, QuicGuid guid, + CachedState* cached, QuicCryptoNegotiatedParameters* out_params, string* error_details) { DCHECK(error_details != NULL); @@ -849,6 +866,12 @@ QuicErrorCode QuicCryptoClientConfig::ProcessServerHello( return QUIC_INVALID_CRYPTO_MESSAGE_TYPE; } + // Learn about updated source address tokens. + StringPiece token; + if (server_hello.GetStringPiece(kSourceAddressTokenTag, &token)) { + cached->set_source_address_token(token); + } + // TODO(agl): // learn about updated SCFGs. diff --git a/chromium/net/quic/crypto/crypto_handshake.h b/chromium/net/quic/crypto/crypto_handshake.h index fdc92a0fc38..cec393f7c77 100644 --- a/chromium/net/quic/crypto/crypto_handshake.h +++ b/chromium/net/quic/crypto/crypto_handshake.h @@ -73,9 +73,6 @@ class NET_EXPORT_PRIVATE CryptoHandshakeMessage { const QuicTagValueMap& tag_value_map() const { return tag_value_map_; } - void Insert(QuicTagValueMap::const_iterator begin, - QuicTagValueMap::const_iterator end); - // SetTaglist sets an element with the given tag to contain a list of tags, // passed as varargs. The argument list must be terminated with a 0 element. void SetTaglist(QuicTag tag, ...); @@ -160,7 +157,6 @@ struct NET_EXPORT_PRIVATE QuicCryptoNegotiatedParameters { QuicCryptoNegotiatedParameters(); ~QuicCryptoNegotiatedParameters(); - uint16 version; QuicTag key_exchange; QuicTag aead; std::string initial_premaster_secret; @@ -266,6 +262,9 @@ class NET_EXPORT_PRIVATE QuicCryptoClientConfig : public QuicCryptoConfig { void SetProof(const std::vector<std::string>& certs, base::StringPiece signature); + // Clears the certificate chain and signature and invalidates the proof. + void ClearProof(); + // SetProofValid records that the certificate chain and signature have been // validated and that it's safe to assume that the server is legitimate. // (Note: this does not check the chain or signature.) @@ -353,18 +352,20 @@ class NET_EXPORT_PRIVATE QuicCryptoClientConfig : public QuicCryptoConfig { // state about a future handshake (i.e. an nonce value from the server), then // it will be saved in |out_params|. |now| is used to judge whether the // server config in the rejection message has expired. - QuicErrorCode ProcessRejection(CachedState* cached, - const CryptoHandshakeMessage& rej, + QuicErrorCode ProcessRejection(const CryptoHandshakeMessage& rej, QuicWallTime now, + CachedState* cached, QuicCryptoNegotiatedParameters* out_params, std::string* error_details); - // ProcessServerHello processes the message in |server_hello|, writes the - // negotiated parameters to |out_params| and returns QUIC_NO_ERROR. If - // |server_hello| is unacceptable then it puts an error message in - // |error_details| and returns an error code. + // ProcessServerHello processes the message in |server_hello|, updates the + // cached information about that server, writes the negotiated parameters to + // |out_params| and returns QUIC_NO_ERROR. If |server_hello| is unacceptable + // then it puts an error message in |error_details| and returns an error + // code. QuicErrorCode ProcessServerHello(const CryptoHandshakeMessage& server_hello, QuicGuid guid, + CachedState* cached, QuicCryptoNegotiatedParameters* out_params, std::string* error_details); diff --git a/chromium/net/quic/crypto/crypto_protocol.h b/chromium/net/quic/crypto/crypto_protocol.h index 586569a6a0a..4580fce4bb4 100644 --- a/chromium/net/quic/crypto/crypto_protocol.h +++ b/chromium/net/quic/crypto/crypto_protocol.h @@ -59,7 +59,6 @@ const QuicTag kCHID = TAG('C', 'H', 'I', 'D'); // Channel ID. // Client hello tags const QuicTag kVERS = TAG('V', 'E', 'R', 'S'); // Version const QuicTag kNONC = TAG('N', 'O', 'N', 'C'); // The client's nonce -const QuicTag kSSID = TAG('S', 'S', 'I', 'D'); // Session ID const QuicTag kKEXS = TAG('K', 'E', 'X', 'S'); // Key exchange methods const QuicTag kAEAD = TAG('A', 'E', 'A', 'D'); // Authenticated // encryption algorithms diff --git a/chromium/net/quic/crypto/crypto_server_config.cc b/chromium/net/quic/crypto/crypto_server_config.cc index f270ddeb31a..89cea42862d 100644 --- a/chromium/net/quic/crypto/crypto_server_config.cc +++ b/chromium/net/quic/crypto/crypto_server_config.cc @@ -11,6 +11,7 @@ #include "base/strings/string_number_conversions.h" #include "crypto/hkdf.h" #include "crypto/secure_hash.h" +#include "net/base/net_util.h" #include "net/quic/crypto/aes_128_gcm_12_decrypter.h" #include "net/quic/crypto/aes_128_gcm_12_encrypter.h" #include "net/quic/crypto/cert_compressor.h" @@ -56,6 +57,7 @@ QuicCryptoServerConfig::QuicCryptoServerConfig( next_config_promotion_time_(QuicWallTime::Zero()), strike_register_lock_(), server_nonce_strike_register_lock_(), + strike_register_no_startup_period_(false), strike_register_max_entries_(1 << 10), strike_register_window_secs_(600), source_address_token_future_secs_(3600), @@ -304,7 +306,6 @@ struct ClientHelloInfo { QuicErrorCode QuicCryptoServerConfig::ProcessClientHello( const CryptoHandshakeMessage& client_hello, - QuicVersion version, QuicGuid guid, const IPEndPoint& client_ip, const QuicClock* clock, @@ -359,8 +360,7 @@ QuicErrorCode QuicCryptoServerConfig::ProcessClientHello( !info.client_nonce_well_formed || !info.unique || !requested_config.get()) { - BuildRejection(version, primary_config.get(), client_hello, info, rand, - out); + BuildRejection(primary_config.get(), client_hello, info, rand, out); return QUIC_NO_ERROR; } @@ -636,6 +636,8 @@ QuicErrorCode QuicCryptoServerConfig::EvaluateClientHello( static_cast<uint32>(info->now.ToUNIXSeconds()), strike_register_window_secs_, orbit, + strike_register_no_startup_period_ ? + StrikeRegister::NO_STARTUP_PERIOD_NEEDED : StrikeRegister::DENY_REQUESTS_AT_STARTUP)); } @@ -664,7 +666,6 @@ QuicErrorCode QuicCryptoServerConfig::EvaluateClientHello( } void QuicCryptoServerConfig::BuildRejection( - QuicVersion version, const scoped_refptr<Config>& config, const CryptoHandshakeMessage& client_hello, const ClientHelloInfo& info, @@ -708,9 +709,8 @@ void QuicCryptoServerConfig::BuildRejection( const vector<string>* certs; string signature; - if (!proof_source_->GetProof(version, info.sni.as_string(), - config->serialized, x509_ecdsa_supported, - &certs, &signature)) { + if (!proof_source_->GetProof(info.sni.as_string(), config->serialized, + x509_ecdsa_supported, &certs, &signature)) { return; } @@ -908,6 +908,12 @@ void QuicCryptoServerConfig::set_replay_protection(bool on) { replay_protection_ = on; } +void QuicCryptoServerConfig::set_strike_register_no_startup_period() { + base::AutoLock auto_lock(strike_register_lock_); + DCHECK(!strike_register_.get()); + strike_register_no_startup_period_ = true; +} + void QuicCryptoServerConfig::set_strike_register_max_entries( uint32 max_entries) { base::AutoLock locker(strike_register_lock_); @@ -949,7 +955,7 @@ string QuicCryptoServerConfig::NewSourceAddressToken( QuicRandom* rand, QuicWallTime now) const { SourceAddressToken source_address_token; - source_address_token.set_ip(ip.ToString()); + source_address_token.set_ip(IPAddressToPackedString(ip.address())); source_address_token.set_timestamp(now.ToUNIXSeconds()); return source_address_token_boxer_.Box( @@ -972,7 +978,7 @@ bool QuicCryptoServerConfig::ValidateSourceAddressToken( return false; } - if (source_address_token.ip() != ip.ToString()) { + if (source_address_token.ip() != IPAddressToPackedString(ip.address())) { // It's for a different IP address. return false; } diff --git a/chromium/net/quic/crypto/crypto_server_config.h b/chromium/net/quic/crypto/crypto_server_config.h index 364c200a149..4255d228618 100644 --- a/chromium/net/quic/crypto/crypto_server_config.h +++ b/chromium/net/quic/crypto/crypto_server_config.h @@ -116,8 +116,6 @@ class NET_EXPORT_PRIVATE QuicCryptoServerConfig { // an error code is returned. // // client_hello: the incoming client hello message. - // version: the QUIC version for the connection. TODO(wtc): Remove once - // QUIC_VERSION_7 and before are removed. // guid: the GUID for the connection, which is used in key derivation. // client_ip: the IP address of the client, which is used to generate and // validate source-address tokens. @@ -129,7 +127,6 @@ class NET_EXPORT_PRIVATE QuicCryptoServerConfig { // out: the resulting handshake message (either REJ or SHLO) // error_details: used to store a string describing any error. QuicErrorCode ProcessClientHello(const CryptoHandshakeMessage& client_hello, - QuicVersion version, QuicGuid guid, const IPEndPoint& client_ip, const QuicClock* clock, @@ -155,6 +152,10 @@ class NET_EXPORT_PRIVATE QuicCryptoServerConfig { // request to be processed twice. void set_replay_protection(bool on); + // set_strike_register_no_startup_period configures the strike register to + // not have a startup period. + void set_strike_register_no_startup_period(); + // set_strike_register_max_entries sets the maximum number of entries that // the internal strike register will hold. If the strike register fills up // then the oldest entries (by the client's clock) will be dropped. @@ -262,7 +263,6 @@ class NET_EXPORT_PRIVATE QuicCryptoServerConfig { // BuildRejection sets |out| to be a REJ message in reply to |client_hello|. void BuildRejection( - QuicVersion version, const scoped_refptr<Config>& config, const CryptoHandshakeMessage& client_hello, const ClientHelloInfo& info, @@ -351,6 +351,7 @@ class NET_EXPORT_PRIVATE QuicCryptoServerConfig { // These fields store configuration values. See the comments for their // respective setter functions. + bool strike_register_no_startup_period_; uint32 strike_register_max_entries_; uint32 strike_register_window_secs_; uint32 source_address_token_future_secs_; diff --git a/chromium/net/quic/crypto/crypto_server_test.cc b/chromium/net/quic/crypto/crypto_server_test.cc index 6744d12e5e0..b2cdf820c34 100644 --- a/chromium/net/quic/crypto/crypto_server_test.cc +++ b/chromium/net/quic/crypto/crypto_server_test.cc @@ -72,8 +72,8 @@ class CryptoServerTest : public ::testing::Test { void ShouldSucceed(const CryptoHandshakeMessage& message) { string error_details; QuicErrorCode error = config_.ProcessClientHello( - message, QuicVersionMax(), 1 /* GUID */, addr_, - &clock_, rand_, ¶ms_, &out_, &error_details); + message, 1 /* GUID */, addr_, &clock_, + rand_, ¶ms_, &out_, &error_details); ASSERT_EQ(error, QUIC_NO_ERROR) << "Message failed with error " << error_details << ": " @@ -84,8 +84,8 @@ class CryptoServerTest : public ::testing::Test { const CryptoHandshakeMessage& message) { string error_details; QuicErrorCode error = config_.ProcessClientHello( - message, QuicVersionMax(), 1 /* GUID */, addr_, - &clock_, rand_, ¶ms_, &out_, &error_details); + message, 1 /* GUID */, addr_, &clock_, + rand_, ¶ms_, &out_, &error_details); ASSERT_NE(error, QUIC_NO_ERROR) << "Message didn't fail: " << message.DebugString(); diff --git a/chromium/net/quic/crypto/proof_source.h b/chromium/net/quic/crypto/proof_source.h index ba5087b0f61..4482dd99561 100644 --- a/chromium/net/quic/crypto/proof_source.h +++ b/chromium/net/quic/crypto/proof_source.h @@ -9,7 +9,6 @@ #include <vector> #include "net/base/net_export.h" -#include "net/quic/quic_protocol.h" namespace net { @@ -28,9 +27,6 @@ class NET_EXPORT_PRIVATE ProofSource { // // The signature uses SHA-256 as the hash function when the key is ECDSA. // - // |version| is the QUIC version for the connection. TODO(wtc): Remove once - // QUIC_VERSION_7 and before are removed. - // // If |ecdsa_ok| is true, the signature may use an ECDSA key. Otherwise, the // signature must use an RSA key. // @@ -49,8 +45,7 @@ class NET_EXPORT_PRIVATE ProofSource { // used. // // This function may be called concurrently. - virtual bool GetProof(QuicVersion version, - const std::string& hostname, + virtual bool GetProof(const std::string& hostname, const std::string& server_config, bool ecdsa_ok, const std::vector<std::string>** out_certs, diff --git a/chromium/net/quic/crypto/proof_source_chromium.cc b/chromium/net/quic/crypto/proof_source_chromium.cc index 4c1fe263b62..75226313818 100644 --- a/chromium/net/quic/crypto/proof_source_chromium.cc +++ b/chromium/net/quic/crypto/proof_source_chromium.cc @@ -12,8 +12,7 @@ namespace net { ProofSourceChromium::ProofSourceChromium() { } -bool ProofSourceChromium::GetProof(QuicVersion version, - const string& hostname, +bool ProofSourceChromium::GetProof(const string& hostname, const string& server_config, bool ecdsa_ok, const vector<string>** out_certs, diff --git a/chromium/net/quic/crypto/proof_source_chromium.h b/chromium/net/quic/crypto/proof_source_chromium.h index 2b93e2d9a4c..70ab92d91cf 100644 --- a/chromium/net/quic/crypto/proof_source_chromium.h +++ b/chromium/net/quic/crypto/proof_source_chromium.h @@ -23,8 +23,7 @@ class NET_EXPORT_PRIVATE ProofSourceChromium : public ProofSource { virtual ~ProofSourceChromium() {} // ProofSource interface - virtual bool GetProof(QuicVersion version, - const std::string& hostname, + virtual bool GetProof(const std::string& hostname, const std::string& server_config, bool ecdsa_ok, const std::vector<std::string>** out_certs, diff --git a/chromium/net/quic/crypto/proof_test.cc b/chromium/net/quic/crypto/proof_test.cc index 97b0dcb4ea9..e4e661a298c 100644 --- a/chromium/net/quic/crypto/proof_test.cc +++ b/chromium/net/quic/crypto/proof_test.cc @@ -25,21 +25,7 @@ using std::vector; namespace net { namespace test { -class ProofTest : public ::testing::TestWithParam<QuicVersion> { - protected: - ProofTest() { - version_ = GetParam(); - } - - QuicVersion version_; -}; - -// Run all ProofTests with QUIC versions 7 and 8. -INSTANTIATE_TEST_CASE_P(ProofTests, - ProofTest, - ::testing::Values(QUIC_VERSION_7, QUIC_VERSION_8)); - -TEST_P(ProofTest, Verify) { +TEST(ProofTest, Verify) { // TODO(rtenneti): Enable testing of ProofVerifier. #if 0 scoped_ptr<ProofSource> source(CryptoTestUtils::ProofSourceForTesting()); @@ -53,11 +39,10 @@ TEST_P(ProofTest, Verify) { string error_details, signature, first_signature; CertVerifyResult cert_verify_result; - ASSERT_TRUE(source->GetProof(version_, hostname, server_config, - false /* no ECDSA */, &first_certs, - &first_signature)); - ASSERT_TRUE(source->GetProof(version_, hostname, server_config, - false /* no ECDSA */, &certs, &signature)); + ASSERT_TRUE(source->GetProof(hostname, server_config, false /* no ECDSA */, + &first_certs, &first_signature)); + ASSERT_TRUE(source->GetProof(hostname, server_config, false /* no ECDSA */, + &certs, &signature)); // Check that the proof source is caching correctly: ASSERT_EQ(first_certs, certs); @@ -65,23 +50,22 @@ TEST_P(ProofTest, Verify) { int rv; TestCompletionCallback callback; - rv = verifier->VerifyProof(version_, hostname, server_config, *certs, - signature, &error_details, &cert_verify_result, + rv = verifier->VerifyProof(hostname, server_config, *certs, signature, + &error_details, &cert_verify_result, callback.callback()); rv = callback.GetResult(rv); ASSERT_EQ(OK, rv); ASSERT_EQ("", error_details); ASSERT_FALSE(IsCertStatusError(cert_verify_result.cert_status)); - rv = verifier->VerifyProof(version_, "foo.com", server_config, *certs, - signature, &error_details, &cert_verify_result, + rv = verifier->VerifyProof("foo.com", server_config, *certs, signature, + &error_details, &cert_verify_result, callback.callback()); rv = callback.GetResult(rv); ASSERT_EQ(ERR_FAILED, rv); ASSERT_NE("", error_details); - rv = verifier->VerifyProof(version_, hostname, - server_config.substr(1, string::npos), + rv = verifier->VerifyProof(hostname, server_config.substr(1, string::npos), *certs, signature, &error_details, &cert_verify_result, callback.callback()); rv = callback.GetResult(rv); @@ -89,7 +73,7 @@ TEST_P(ProofTest, Verify) { ASSERT_NE("", error_details); const string corrupt_signature = "1" + signature; - rv = verifier->VerifyProof(version_, hostname, server_config, *certs, + rv = verifier->VerifyProof(hostname, server_config, *certs, corrupt_signature, &error_details, &cert_verify_result, callback.callback()); rv = callback.GetResult(rv); @@ -100,8 +84,8 @@ TEST_P(ProofTest, Verify) { for (size_t i = 1; i < certs->size(); i++) { wrong_certs.push_back((*certs)[i]); } - rv = verifier->VerifyProof(version_, "foo.com", server_config, wrong_certs, - signature, &error_details, &cert_verify_result, + rv = verifier->VerifyProof("foo.com", server_config, wrong_certs, signature, + &error_details, &cert_verify_result, callback.callback()); rv = callback.GetResult(rv); ASSERT_EQ(ERR_FAILED, rv); @@ -138,8 +122,7 @@ class TestProofVerifierCallback : public ProofVerifierCallback { // RunVerification runs |verifier->VerifyProof| and asserts that the result // matches |expected_ok|. -static void RunVerification(QuicVersion version, - ProofVerifier* verifier, +static void RunVerification(ProofVerifier* verifier, const std::string& hostname, const std::string& server_config, const vector<std::string>& certs, @@ -153,7 +136,7 @@ static void RunVerification(QuicVersion version, new TestProofVerifierCallback(&comp_callback, &ok, &error_details); ProofVerifier::Status status = verifier->VerifyProof( - version, hostname, server_config, certs, proof, &error_details, &details, + hostname, server_config, certs, proof, &error_details, &details, callback); switch (status) { @@ -185,56 +168,11 @@ static string PEMCertFileToDER(const string& file_name) { // A known answer test that allows us to test ProofVerifier without a working // ProofSource. -TEST_P(ProofTest, VerifyRSAKnownAnswerTest) { +TEST(ProofTest, VerifyRSAKnownAnswerTest) { // These sample signatures were generated by running the Proof.Verify test // and dumping the bytes of the |signature| output of ProofSource::GetProof(). // sLen = special value -2 used by OpenSSL. static const unsigned char signature_data_0[] = { - 0x4c, 0x68, 0x3c, 0xc2, 0x1f, 0x31, 0x73, 0xa5, 0x29, 0xd3, - 0x56, 0x75, 0xb1, 0xbf, 0xbd, 0x31, 0x17, 0xfb, 0x2e, 0x24, - 0xb3, 0xc4, 0x0d, 0xfa, 0x56, 0xb8, 0x65, 0x94, 0x12, 0x38, - 0x6e, 0xff, 0xb3, 0x10, 0x2e, 0xf8, 0x5c, 0xc1, 0x21, 0x9d, - 0x29, 0x0c, 0x3a, 0x0a, 0x1a, 0xbf, 0x6b, 0x1c, 0x63, 0x77, - 0xf7, 0x86, 0xd3, 0xa4, 0x36, 0xf2, 0xb1, 0x6f, 0xac, 0xc3, - 0x23, 0x8d, 0xda, 0xe6, 0xd5, 0x83, 0xba, 0xdf, 0x28, 0x3e, - 0x7f, 0x4e, 0x79, 0xfc, 0xba, 0xdb, 0xf7, 0xd0, 0x4b, 0xad, - 0x79, 0xd0, 0xeb, 0xcf, 0xfa, 0x6e, 0x84, 0x44, 0x7a, 0x26, - 0xb1, 0x29, 0xa3, 0x08, 0xa8, 0x63, 0xfd, 0xed, 0x85, 0xff, - 0x9a, 0xe6, 0x79, 0x8b, 0xb6, 0x81, 0x13, 0x2c, 0xde, 0xe2, - 0xd8, 0x31, 0x29, 0xa4, 0xe0, 0x1b, 0x75, 0x2d, 0x8a, 0xf8, - 0x27, 0x55, 0xbc, 0xc7, 0x3b, 0x1e, 0xc1, 0x42, - }; - static const unsigned char signature_data_1[] = { - 0xbb, 0xd1, 0x17, 0x43, 0xf3, 0x42, 0x16, 0xe9, 0xf9, 0x76, - 0xe6, 0xe3, 0xaa, 0x50, 0x47, 0x5f, 0x93, 0xb6, 0x7d, 0x35, - 0x03, 0x49, 0x0a, 0x07, 0x61, 0xd5, 0xf1, 0x9c, 0x6b, 0xaf, - 0xaa, 0xd7, 0x64, 0xe4, 0x0a, 0x0c, 0xab, 0x97, 0xfb, 0x4e, - 0x5c, 0x14, 0x08, 0xf6, 0xb9, 0xa9, 0x1d, 0xa9, 0xf8, 0x6d, - 0xb0, 0x2b, 0x2a, 0x0e, 0xc4, 0xd0, 0xd2, 0xe9, 0x96, 0x4f, - 0x44, 0x70, 0x90, 0x46, 0xb9, 0xd5, 0x89, 0x72, 0xb9, 0xa8, - 0xe4, 0xfb, 0x88, 0xbc, 0x69, 0x7f, 0xc9, 0xdc, 0x84, 0x87, - 0x18, 0x21, 0x9b, 0xde, 0x22, 0x33, 0xde, 0x16, 0x3f, 0xe6, - 0xfd, 0x27, 0x56, 0xd3, 0xa4, 0x97, 0x91, 0x65, 0x1a, 0xe7, - 0x5e, 0x80, 0x9a, 0xbf, 0xbf, 0x1a, 0x29, 0x8a, 0xbe, 0xa2, - 0x8c, 0x9c, 0x23, 0xf4, 0xcb, 0xba, 0x79, 0x31, 0x28, 0xab, - 0x77, 0x94, 0x92, 0xb2, 0xc2, 0x35, 0xb2, 0xfa, - }; - static const unsigned char signature_data_2[] = { - 0x7e, 0x17, 0x01, 0xcb, 0x76, 0x9e, 0x9f, 0xce, 0xeb, 0x66, - 0x3e, 0xaa, 0xc9, 0x36, 0x5b, 0x7e, 0x48, 0x25, 0x99, 0xf8, - 0x0d, 0xe1, 0xa8, 0x48, 0x93, 0x3c, 0xe8, 0x97, 0x2e, 0x98, - 0xd6, 0x73, 0x0f, 0xd0, 0x74, 0x9c, 0x17, 0xef, 0xee, 0xf8, - 0x0e, 0x2a, 0x27, 0x3f, 0xc6, 0x55, 0xc6, 0xb9, 0xfe, 0x17, - 0xcc, 0xeb, 0x5d, 0xa1, 0xdc, 0xbd, 0x64, 0xd9, 0x5e, 0xec, - 0x57, 0x9d, 0xc3, 0xdc, 0x11, 0xbf, 0x23, 0x02, 0x58, 0xc4, - 0xf1, 0x18, 0xc1, 0x6f, 0x3f, 0xef, 0x18, 0x4d, 0xa6, 0x1e, - 0xe8, 0x25, 0x32, 0x8f, 0x92, 0x1e, 0xad, 0xbc, 0xbe, 0xde, - 0x83, 0x2a, 0x92, 0xd5, 0x59, 0x6f, 0xe4, 0x95, 0x6f, 0xe6, - 0xb1, 0xf9, 0xaf, 0x3f, 0xdb, 0x69, 0x6f, 0xae, 0xa6, 0x36, - 0xd2, 0x50, 0x81, 0x78, 0x41, 0x13, 0x2c, 0x65, 0x9c, 0x9e, - 0xf4, 0xd2, 0xd5, 0x58, 0x5b, 0x8b, 0x87, 0xcf, - }; - static const unsigned char signature_data_4[] = { 0x9e, 0xe6, 0x74, 0x3b, 0x8f, 0xb8, 0x66, 0x77, 0x57, 0x09, 0x8a, 0x04, 0xe9, 0xf0, 0x7c, 0x91, 0xa9, 0x5c, 0xe9, 0xdf, 0x12, 0x4d, 0x23, 0x82, 0x8c, 0x29, 0x72, 0x7f, 0xc2, 0x20, @@ -249,7 +187,7 @@ TEST_P(ProofTest, VerifyRSAKnownAnswerTest) { 0x78, 0xc8, 0x8b, 0xf5, 0xb9, 0x36, 0x5d, 0x72, 0x1f, 0xfc, 0x14, 0xff, 0xa7, 0x81, 0x27, 0x49, 0xae, 0xe1, }; - static const unsigned char signature_data_5[] = { + static const unsigned char signature_data_1[] = { 0x5e, 0xc2, 0xab, 0x6b, 0x16, 0xe6, 0x55, 0xf3, 0x16, 0x46, 0x35, 0xdc, 0xcc, 0xde, 0xd0, 0xbd, 0x6c, 0x66, 0xb2, 0x3d, 0xd3, 0x14, 0x78, 0xed, 0x47, 0x55, 0xfb, 0xdb, 0xe1, 0x7d, @@ -264,7 +202,7 @@ TEST_P(ProofTest, VerifyRSAKnownAnswerTest) { 0xaf, 0x6b, 0x47, 0xbc, 0x16, 0x55, 0x37, 0x0a, 0xbe, 0x0e, 0xc5, 0x75, 0x3f, 0x3d, 0x8e, 0xe8, 0x44, 0xe3, }; - static const unsigned char signature_data_6[] = { + static const unsigned char signature_data_2[] = { 0x8e, 0x5c, 0x78, 0x63, 0x74, 0x99, 0x2e, 0x96, 0xc0, 0x14, 0x8d, 0xb5, 0x13, 0x74, 0xa3, 0xa4, 0xe0, 0x43, 0x3e, 0x85, 0xba, 0x8f, 0x3c, 0x5e, 0x14, 0x64, 0x0e, 0x5e, 0xff, 0x89, @@ -295,52 +233,41 @@ TEST_P(ProofTest, VerifyRSAKnownAnswerTest) { // Signatures are nondeterministic, so we test multiple signatures on the // same server_config. vector<string> signatures(3); - if (version_ < QUIC_VERSION_8) { - signatures[0].assign(reinterpret_cast<const char*>(signature_data_0), - sizeof(signature_data_0)); - signatures[1].assign(reinterpret_cast<const char*>(signature_data_1), - sizeof(signature_data_1)); - signatures[2].assign(reinterpret_cast<const char*>(signature_data_2), - sizeof(signature_data_2)); - } else { - signatures[0].assign(reinterpret_cast<const char*>(signature_data_4), - sizeof(signature_data_4)); - signatures[1].assign(reinterpret_cast<const char*>(signature_data_5), - sizeof(signature_data_5)); - signatures[2].assign(reinterpret_cast<const char*>(signature_data_6), - sizeof(signature_data_6)); - } + signatures[0].assign(reinterpret_cast<const char*>(signature_data_0), + sizeof(signature_data_0)); + signatures[1].assign(reinterpret_cast<const char*>(signature_data_1), + sizeof(signature_data_1)); + signatures[2].assign(reinterpret_cast<const char*>(signature_data_2), + sizeof(signature_data_2)); for (size_t i = 0; i < signatures.size(); i++) { const string& signature = signatures[i]; RunVerification( - version_, verifier.get(), hostname, server_config, certs, signature, - true); + verifier.get(), hostname, server_config, certs, signature, true); RunVerification( - version_, verifier.get(), "foo.com", server_config, certs, signature, - false); + verifier.get(), "foo.com", server_config, certs, signature, false); RunVerification( - version_, verifier.get(), hostname, - server_config.substr(1, string::npos), certs, signature, false); + verifier.get(), hostname, server_config.substr(1, string::npos), + certs, signature, false); const string corrupt_signature = "1" + signature; RunVerification( - version_, verifier.get(), hostname, server_config, certs, - corrupt_signature, false); + verifier.get(), hostname, server_config, certs, corrupt_signature, + false); vector<string> wrong_certs; for (size_t i = 1; i < certs.size(); i++) { wrong_certs.push_back(certs[i]); } - RunVerification(version_, verifier.get(), hostname, server_config, - wrong_certs, signature, false); + RunVerification(verifier.get(), hostname, server_config, wrong_certs, + signature, false); } } // A known answer test that allows us to test ProofVerifier without a working // ProofSource. -TEST_P(ProofTest, VerifyECDSAKnownAnswerTest) { +TEST(ProofTest, VerifyECDSAKnownAnswerTest) { // Disable this test on platforms that do not support ECDSA certificates. #if defined(OS_WIN) if (base::win::GetVersion() < base::win::VERSION_VISTA) @@ -406,36 +333,34 @@ TEST_P(ProofTest, VerifyECDSAKnownAnswerTest) { const string& signature = signatures[i]; RunVerification( - version_, verifier.get(), hostname, server_config, certs, signature, - true); + verifier.get(), hostname, server_config, certs, signature, true); RunVerification( - version_, verifier.get(), "foo.com", server_config, certs, signature, - false); + verifier.get(), "foo.com", server_config, certs, signature, false); RunVerification( - version_, verifier.get(), hostname, - server_config.substr(1, string::npos), certs, signature, false); + verifier.get(), hostname, server_config.substr(1, string::npos), + certs, signature, false); // An ECDSA signature is DER-encoded. Corrupt the last byte so that the // signature can still be DER-decoded correctly. string corrupt_signature = signature; corrupt_signature[corrupt_signature.size() - 1] += 1; RunVerification( - version_, verifier.get(), hostname, server_config, certs, - corrupt_signature, false); + verifier.get(), hostname, server_config, certs, corrupt_signature, + false); // Prepending a "1" makes the DER invalid. const string bad_der_signature1 = "1" + signature; RunVerification( - version_, verifier.get(), hostname, server_config, certs, - bad_der_signature1, false); + verifier.get(), hostname, server_config, certs, bad_der_signature1, + false); vector<string> wrong_certs; for (size_t i = 1; i < certs.size(); i++) { wrong_certs.push_back(certs[i]); } RunVerification( - version_, verifier.get(), hostname, server_config, wrong_certs, - signature, false); + verifier.get(), hostname, server_config, wrong_certs, signature, + false); } } diff --git a/chromium/net/quic/crypto/proof_verifier.h b/chromium/net/quic/crypto/proof_verifier.h index ecab113e694..f469c552959 100644 --- a/chromium/net/quic/crypto/proof_verifier.h +++ b/chromium/net/quic/crypto/proof_verifier.h @@ -10,7 +10,6 @@ #include "net/base/completion_callback.h" #include "net/base/net_export.h" -#include "net/quic/quic_protocol.h" namespace net { @@ -71,11 +70,7 @@ class NET_EXPORT_PRIVATE ProofVerifier { // // The signature uses SHA-256 as the hash function and PSS padding in the // case of RSA. - // - // |version| is the QUIC version for the connection. TODO(wtc): Remove once - // QUIC_VERSION_7 and before are removed. - virtual Status VerifyProof(QuicVersion version, - const std::string& hostname, + virtual Status VerifyProof(const std::string& hostname, const std::string& server_config, const std::vector<std::string>& certs, const std::string& signature, diff --git a/chromium/net/quic/crypto/proof_verifier_chromium.cc b/chromium/net/quic/crypto/proof_verifier_chromium.cc index 88653053f3e..8c4796204ec 100644 --- a/chromium/net/quic/crypto/proof_verifier_chromium.cc +++ b/chromium/net/quic/crypto/proof_verifier_chromium.cc @@ -42,7 +42,6 @@ ProofVerifierChromium::~ProofVerifierChromium() { } ProofVerifierChromium::Status ProofVerifierChromium::VerifyProof( - QuicVersion version, const string& hostname, const string& server_config, const vector<string>& certs, @@ -90,7 +89,7 @@ ProofVerifierChromium::Status ProofVerifierChromium::VerifyProof( // We call VerifySignature first to avoid copying of server_config and // signature. - if (!VerifySignature(version, server_config, signature, certs[0])) { + if (!VerifySignature(server_config, signature, certs[0])) { *error_details = "Failed to verify signature of server config"; DLOG(WARNING) << *error_details; verify_details_->cert_verify_result.cert_status = CERT_STATUS_INVALID; @@ -177,8 +176,7 @@ int ProofVerifierChromium::DoVerifyCertComplete(int result) { return result; } -bool ProofVerifierChromium::VerifySignature(QuicVersion version, - const string& signed_data, +bool ProofVerifierChromium::VerifySignature(const string& signed_data, const string& signature, const string& cert) { StringPiece spki; @@ -198,11 +196,9 @@ bool ProofVerifierChromium::VerifySignature(QuicVersion version, crypto::SignatureVerifier::SHA256; crypto::SignatureVerifier::HashAlgorithm mask_hash_alg = hash_alg; unsigned int hash_len = 32; // 32 is the length of a SHA-256 hash. - unsigned int salt_len = - version >= QUIC_VERSION_8 ? hash_len : signature.size() - hash_len - 2; bool ok = verifier.VerifyInitRSAPSS( - hash_alg, mask_hash_alg, salt_len, + hash_alg, mask_hash_alg, hash_len, reinterpret_cast<const uint8*>(signature.data()), signature.size(), reinterpret_cast<const uint8*>(spki.data()), spki.size()); if (!ok) { diff --git a/chromium/net/quic/crypto/proof_verifier_chromium.h b/chromium/net/quic/crypto/proof_verifier_chromium.h index 8786e52e7dd..4969cc8aa55 100644 --- a/chromium/net/quic/crypto/proof_verifier_chromium.h +++ b/chromium/net/quic/crypto/proof_verifier_chromium.h @@ -39,8 +39,7 @@ class NET_EXPORT_PRIVATE ProofVerifierChromium : public ProofVerifier { virtual ~ProofVerifierChromium(); // ProofVerifier interface - virtual Status VerifyProof(QuicVersion version, - const std::string& hostname, + virtual Status VerifyProof(const std::string& hostname, const std::string& server_config, const std::vector<std::string>& certs, const std::string& signature, @@ -60,8 +59,7 @@ class NET_EXPORT_PRIVATE ProofVerifierChromium : public ProofVerifier { int DoVerifyCert(int result); int DoVerifyCertComplete(int result); - bool VerifySignature(QuicVersion version, - const std::string& signed_data, + bool VerifySignature(const std::string& signed_data, const std::string& signature, const std::string& cert); diff --git a/chromium/net/quic/crypto/source_address_token.cc b/chromium/net/quic/crypto/source_address_token.cc index d15afebf2a7..b095e762265 100644 --- a/chromium/net/quic/crypto/source_address_token.cc +++ b/chromium/net/quic/crypto/source_address_token.cc @@ -21,24 +21,36 @@ SourceAddressToken::~SourceAddressToken() { } string SourceAddressToken::SerializeAsString() const { - return ip_ + " " + base::Int64ToString(timestamp_); + string out; + out.push_back(ip_.size()); + out.append(ip_); + string time_str = base::Int64ToString(timestamp_); + out.push_back(time_str.size()); + out.append(time_str); + return out; } bool SourceAddressToken::ParseFromArray(const char* plaintext, size_t plaintext_length) { - string data(plaintext, plaintext_length); - vector<string> results; - base::SplitString(data, ' ', &results); - if (results.size() < 2) { + if (plaintext_length == 0) { + return false; + } + size_t ip_len = plaintext[0]; + if (plaintext_length <= 1 + ip_len) { + return false; + } + size_t time_len = plaintext[1 + ip_len]; + if (plaintext_length != 1 + ip_len + 1 + time_len) { return false; } + string time_str(&plaintext[1 + ip_len + 1], time_len); int64 timestamp; - if (!base::StringToInt64(results[1], ×tamp)) { + if (!base::StringToInt64(time_str, ×tamp)) { return false; } - ip_ = results[0]; + ip_.assign(&plaintext[1], ip_len); timestamp_ = timestamp; return true; } diff --git a/chromium/net/quic/crypto/strike_register.cc b/chromium/net/quic/crypto/strike_register.cc index 97aca184cd0..f45bfabd9f0 100644 --- a/chromium/net/quic/crypto/strike_register.cc +++ b/chromium/net/quic/crypto/strike_register.cc @@ -56,8 +56,8 @@ class StrikeRegister::InternalNode { }; // kCreationTimeFromInternalEpoch contains the number of seconds between the -// start of the internal epoch and |creation_time_external_|. This allows us -// to consider times that are before |creation_time_external_|. +// start of the internal epoch and the creation time. This allows us +// to consider times that are before the creation time. static const uint32 kCreationTimeFromInternalEpoch = 63115200.0; // 2 years. StrikeRegister::StrikeRegister(unsigned max_entries, @@ -67,22 +67,17 @@ StrikeRegister::StrikeRegister(unsigned max_entries, StartupType startup) : max_entries_(max_entries), window_secs_(window_secs), + internal_epoch_(current_time > kCreationTimeFromInternalEpoch + ? current_time - kCreationTimeFromInternalEpoch + : 0), // The horizon is initially set |window_secs| into the future because, if // we just crashed, then we may have accepted nonces in the span // [current_time...current_time+window_secs) and so we conservatively // reject the whole timespan unless |startup| tells us otherwise. - creation_time_external_(current_time), - internal_epoch_(current_time > kCreationTimeFromInternalEpoch - ? current_time - kCreationTimeFromInternalEpoch - : 0), horizon_(ExternalTimeToInternal(current_time) + window_secs), horizon_valid_(startup == DENY_REQUESTS_AT_STARTUP) { memcpy(orbit_, orbit, sizeof(orbit_)); - // TODO(rtenneti): Remove the following check, Added the following to silence - // "is not used" error. - CHECK_GE(creation_time_external_, 0u); - // We only have 23 bits of index available. CHECK_LT(max_entries, 1u << 23); CHECK_GT(max_entries, 1u); // There must be at least two entries. diff --git a/chromium/net/quic/crypto/strike_register.h b/chromium/net/quic/crypto/strike_register.h index 98bc04cb630..fda62a802b2 100644 --- a/chromium/net/quic/crypto/strike_register.h +++ b/chromium/net/quic/crypto/strike_register.h @@ -129,7 +129,7 @@ class NET_EXPORT_PRIVATE StrikeRegister { static uint32 TimeFromBytes(const uint8 d[4]); // ExternalTimeToInternal converts an external time value into an internal - // time value using |creation_time_external_|. + // time value using |internal_epoch_|. uint32 ExternalTimeToInternal(uint32 external_time); // BestMatch returns either kNil, or an external node index which could @@ -164,10 +164,6 @@ class NET_EXPORT_PRIVATE StrikeRegister { const uint32 max_entries_; const uint32 window_secs_; - // creation_time_external_ contains the uint32, external time when this - // object was created (i.e. the value passed to the constructor). This is - // used to translate external times to internal times. - const uint32 creation_time_external_; // internal_epoch_ contains the external time value of the start of internal // time. const uint32 internal_epoch_; diff --git a/chromium/net/quic/quic_ack_notifier.cc b/chromium/net/quic/quic_ack_notifier.cc new file mode 100644 index 00000000000..662b432d5c8 --- /dev/null +++ b/chromium/net/quic/quic_ack_notifier.cc @@ -0,0 +1,56 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/quic/quic_ack_notifier.h" + +namespace net { + +QuicAckNotifier::DelegateInterface::DelegateInterface() {} + +QuicAckNotifier::DelegateInterface::~DelegateInterface() {} + +QuicAckNotifier::QuicAckNotifier(DelegateInterface* delegate) + : delegate_(delegate) { + DCHECK(delegate_); +} + +QuicAckNotifier::~QuicAckNotifier() {} + +void QuicAckNotifier::AddSequenceNumber( + const QuicPacketSequenceNumber& sequence_number) { + sequence_numbers_.insert(sequence_number); +} + +void QuicAckNotifier::AddSequenceNumbers( + const SequenceNumberSet& sequence_numbers) { + for (SequenceNumberSet::const_iterator it = sequence_numbers.begin(); + it != sequence_numbers.end(); ++it) { + AddSequenceNumber(*it); + } +} + +bool QuicAckNotifier::OnAck(SequenceNumberSet sequence_numbers) { + // If the set of sequence numbers we are tracking is empty then this + // QuicAckNotifier should have already been deleted. + DCHECK(!sequence_numbers_.empty()); + + for (SequenceNumberSet::iterator it = sequence_numbers.begin(); + it != sequence_numbers.end(); ++it) { + sequence_numbers_.erase(*it); + if (sequence_numbers_.empty()) { + delegate_->OnAckNotification(); + return true; + } + } + return false; +} + +void QuicAckNotifier::UpdateSequenceNumber( + QuicPacketSequenceNumber old_sequence_number, + QuicPacketSequenceNumber new_sequence_number) { + sequence_numbers_.erase(old_sequence_number); + sequence_numbers_.insert(new_sequence_number); +} + +}; // namespace net diff --git a/chromium/net/quic/quic_ack_notifier.h b/chromium/net/quic/quic_ack_notifier.h new file mode 100644 index 00000000000..8470681c650 --- /dev/null +++ b/chromium/net/quic/quic_ack_notifier.h @@ -0,0 +1,66 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_QUIC_QUIC_ACK_NOTIFIER_H_ +#define NET_QUIC_QUIC_ACK_NOTIFIER_H_ + +#include "base/callback.h" +#include "net/quic/quic_protocol.h" + +namespace net { + +// Used to register with a QuicConnection for notification once a set of packets +// have all been ACKed. +// The connection informs this class of newly ACKed sequence numbers, and once +// we have seen ACKs for all the sequence numbers we are interested in, we +// trigger a call to a provided Closure. +class NET_EXPORT_PRIVATE QuicAckNotifier { + public: + class NET_EXPORT_PRIVATE DelegateInterface { + public: + DelegateInterface(); + virtual ~DelegateInterface(); + virtual void OnAckNotification() = 0; + }; + + explicit QuicAckNotifier(DelegateInterface* delegate); + virtual ~QuicAckNotifier(); + + // Register a sequence number that this AckNotifier should be interested in. + void AddSequenceNumber(const QuicPacketSequenceNumber& sequence_number); + + // Register a set of sequence numbers that this AckNotifier should be + // interested in. + void AddSequenceNumbers(const SequenceNumberSet& sequence_numbers); + + // Called by the QuicConnection on receipt of new ACK frames with a list of + // ACKed sequence numbers. + // Deletes any matching sequence numbers from the set of sequence numbers + // being tracked. If this set is now empty, call the stored delegate's + // OnAckNotification method. + // + // Returns true if the provided sequence_numbers caused the delegate to be + // called, false otherwise. + bool OnAck(SequenceNumberSet sequence_numbers); + + // If a packet is retransmitted by the connection it will be sent with a + // different sequence number. Updates our internal set of sequence_numbers to + // track the latest number. + void UpdateSequenceNumber(QuicPacketSequenceNumber old_sequence_number, + QuicPacketSequenceNumber new_sequence_number); + + private: + // The delegate's OnAckNotification() method will be called once we have been + // notified of ACKs for all the sequence numbers we are tracking. + // This is not owned by OnAckNotifier and must outlive it. + DelegateInterface* delegate_; + + // Set of sequence numbers this notifier is waiting to hear about. The + // delegate will not be called until this is an empty set. + SequenceNumberSet sequence_numbers_; +}; + +}; // namespace net + +#endif // NET_QUIC_QUIC_ACK_NOTIFIER_H_ diff --git a/chromium/net/quic/quic_ack_notifier_test.cc b/chromium/net/quic/quic_ack_notifier_test.cc new file mode 100644 index 00000000000..aa1ebfa955f --- /dev/null +++ b/chromium/net/quic/quic_ack_notifier_test.cc @@ -0,0 +1,106 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/quic/quic_ack_notifier.h" + +#include "net/quic/test_tools/quic_test_utils.h" +#include "testing/gmock/include/gmock/gmock.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace net { +namespace test { +namespace { + +class QuicAckNotifierTest : public ::testing::Test { + protected: + virtual void SetUp() { + notifier_.reset(new QuicAckNotifier(&delegate_)); + + sequence_numbers_.insert(26); + sequence_numbers_.insert(99); + sequence_numbers_.insert(1234); + notifier_->AddSequenceNumbers(sequence_numbers_); + } + + SequenceNumberSet sequence_numbers_; + MockAckNotifierDelegate delegate_; + scoped_ptr<QuicAckNotifier> notifier_; +}; + +// Should trigger callback when we receive acks for all the registered seqnums. +TEST_F(QuicAckNotifierTest, TriggerCallback) { + EXPECT_CALL(delegate_, OnAckNotification()).Times(1); + EXPECT_TRUE(notifier_->OnAck(sequence_numbers_)); +} + +// Should trigger callback when we receive acks for all the registered seqnums, +// even though they are interspersed with other seqnums. +TEST_F(QuicAckNotifierTest, TriggerCallbackInterspersed) { + sequence_numbers_.insert(3); + sequence_numbers_.insert(55); + sequence_numbers_.insert(805); + + EXPECT_CALL(delegate_, OnAckNotification()).Times(1); + EXPECT_TRUE(notifier_->OnAck(sequence_numbers_)); +} + +// Should trigger callback when we receive acks for all the registered seqnums, +// even though they are split over multiple calls to OnAck. +TEST_F(QuicAckNotifierTest, TriggerCallbackMultipleCalls) { + SequenceNumberSet seqnums; + seqnums.insert(26); + EXPECT_CALL(delegate_, OnAckNotification()).Times(0); + EXPECT_FALSE(notifier_->OnAck(seqnums)); + + seqnums.clear(); + seqnums.insert(55); + seqnums.insert(9001); + seqnums.insert(99); + EXPECT_CALL(delegate_, OnAckNotification()).Times(0); + EXPECT_FALSE(notifier_->OnAck(seqnums)); + + seqnums.clear(); + seqnums.insert(1234); + EXPECT_CALL(delegate_, OnAckNotification()).Times(1); + EXPECT_TRUE(notifier_->OnAck(seqnums)); +} + +// Should not trigger callback if we never provide all the seqnums. +TEST_F(QuicAckNotifierTest, DoesNotTrigger) { + SequenceNumberSet different_seqnums; + different_seqnums.insert(14); + different_seqnums.insert(15); + different_seqnums.insert(16); + + // Should not trigger callback as not all packets have been seen. + EXPECT_CALL(delegate_, OnAckNotification()).Times(0); + EXPECT_FALSE(notifier_->OnAck(different_seqnums)); +} + +// Should trigger even after updating sequence numbers and receiving ACKs for +// new sequeunce numbers. +TEST_F(QuicAckNotifierTest, UpdateSeqNums) { + // Uninteresting sequeunce numbers shouldn't trigger callback. + SequenceNumberSet seqnums; + seqnums.insert(6); + seqnums.insert(7); + seqnums.insert(2000); + EXPECT_CALL(delegate_, OnAckNotification()).Times(0); + EXPECT_FALSE(notifier_->OnAck(seqnums)); + + // Update a couple of the sequence numbers (i.e. retransmitted packets) + notifier_->UpdateSequenceNumber(99, 3000); + notifier_->UpdateSequenceNumber(1234, 3001); + + seqnums.clear(); + seqnums.insert(26); // original, unchanged + seqnums.insert(3000); // updated + seqnums.insert(3001); // updated + EXPECT_CALL(delegate_, OnAckNotification()).Times(1); + EXPECT_TRUE(notifier_->OnAck(seqnums)); +} + +} // namespace +} // namespace test +} // namespace net diff --git a/chromium/net/quic/quic_client_session.cc b/chromium/net/quic/quic_client_session.cc index d7fb0d2675e..ca6941c7e1b 100644 --- a/chromium/net/quic/quic_client_session.cc +++ b/chromium/net/quic/quic_client_session.cc @@ -81,7 +81,7 @@ void QuicClientSession::StreamRequest::OnRequestCompleteFailure(int rv) { QuicClientSession::QuicClientSession( QuicConnection* connection, - DatagramClientSocket* socket, + scoped_ptr<DatagramClientSocket> socket, QuicStreamFactory* stream_factory, QuicCryptoClientStreamFactory* crypto_client_stream_factory, const string& server_hostname, @@ -89,14 +89,15 @@ QuicClientSession::QuicClientSession( QuicCryptoClientConfig* crypto_config, NetLog* net_log) : QuicSession(connection, config, false), - weak_factory_(this), + require_confirmation_(false), stream_factory_(stream_factory), - socket_(socket), + socket_(socket.Pass()), read_buffer_(new IOBufferWithSize(kMaxPacketSize)), read_pending_(false), num_total_streams_(0), net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_QUIC_SESSION)), - logger_(net_log_) { + logger_(net_log_), + weak_factory_(this) { crypto_stream_.reset( crypto_client_stream_factory ? crypto_client_stream_factory->CreateQuicCryptoClientStream( @@ -209,7 +210,9 @@ bool QuicClientSession::GetSSLInfo(SSLInfo* ssl_info) { return crypto_stream_->GetSSLInfo(ssl_info); } -int QuicClientSession::CryptoConnect(const CompletionCallback& callback) { +int QuicClientSession::CryptoConnect(bool require_confirmation, + const CompletionCallback& callback) { + require_confirmation_ = require_confirmation; RecordHandshakeState(STATE_STARTED); if (!crypto_stream_->CryptoConnect()) { // TODO(wtc): change crypto_stream_.CryptoConnect() to return a @@ -217,7 +220,9 @@ int QuicClientSession::CryptoConnect(const CompletionCallback& callback) { return ERR_CONNECTION_FAILED; } - if (IsEncryptionEstablished()) { + bool can_notify = require_confirmation_ ? + IsCryptoHandshakeConfirmed() : IsEncryptionEstablished(); + if (can_notify) { return OK; } @@ -225,6 +230,10 @@ int QuicClientSession::CryptoConnect(const CompletionCallback& callback) { return ERR_IO_PENDING; } +int QuicClientSession::GetNumSentClientHellos() const { + return crypto_stream_->num_sent_client_hellos(); +} + ReliableQuicStream* QuicClientSession::CreateIncomingReliableStream( QuicStreamId id) { DLOG(ERROR) << "Server push not supported"; @@ -233,7 +242,16 @@ ReliableQuicStream* QuicClientSession::CreateIncomingReliableStream( void QuicClientSession::CloseStream(QuicStreamId stream_id) { QuicSession::CloseStream(stream_id); + OnClosedStream(); +} + +void QuicClientSession::SendRstStream(QuicStreamId id, + QuicRstStreamErrorCode error) { + QuicSession::SendRstStream(id, error); + OnClosedStream(); +} +void QuicClientSession::OnClosedStream() { if (GetNumOpenStreams() < get_max_open_streams() && !stream_requests_.empty() && crypto_stream_->encryption_established() && @@ -250,7 +268,8 @@ void QuicClientSession::CloseStream(QuicStreamId stream_id) { } void QuicClientSession::OnCryptoHandshakeEvent(CryptoHandshakeEvent event) { - if (!callback_.is_null()) { + if (!callback_.is_null() && + (!require_confirmation_ || event == HANDSHAKE_CONFIRMED)) { // TODO(rtenneti): Currently for all CryptoHandshakeEvent events, callback_ // could be called because there are no error events in CryptoHandshakeEvent // enum. If error events are added to CryptoHandshakeEvent, then the @@ -260,9 +279,32 @@ void QuicClientSession::OnCryptoHandshakeEvent(CryptoHandshakeEvent event) { QuicSession::OnCryptoHandshakeEvent(event); } +void QuicClientSession::OnCryptoHandshakeMessageSent( + const CryptoHandshakeMessage& message) { + logger_.OnCryptoHandshakeMessageSent(message); +} + +void QuicClientSession::OnCryptoHandshakeMessageReceived( + const CryptoHandshakeMessage& message) { + logger_.OnCryptoHandshakeMessageReceived(message); +} + void QuicClientSession::ConnectionClose(QuicErrorCode error, bool from_peer) { - UMA_HISTOGRAM_SPARSE_SLOWLY("Net.QuicSession.ConnectionCloseErrorCode", - error); + logger_.OnConnectionClose(error, from_peer); + if (from_peer) { + UMA_HISTOGRAM_SPARSE_SLOWLY( + "Net.QuicSession.ConnectionCloseErrorCodeServer", error); + } else { + UMA_HISTOGRAM_SPARSE_SLOWLY( + "Net.QuicSession.ConnectionCloseErrorCodeClient", error); + } + + if (error == QUIC_CONNECTION_TIMED_OUT) { + UMA_HISTOGRAM_SPARSE_SLOWLY( + "Net.QuicSession.ConnectionClose.NumOpenStreams.TimedOut", + GetNumOpenStreams()); + } + UMA_HISTOGRAM_SPARSE_SLOWLY("Net.QuicSession.QuicVersion", connection()->version()); if (!callback_.is_null()) { @@ -272,6 +314,12 @@ void QuicClientSession::ConnectionClose(QuicErrorCode error, bool from_peer) { NotifyFactoryOfSessionCloseLater(); } +void QuicClientSession::OnSuccessfulVersionNegotiation( + const QuicVersion& version) { + logger_.OnSuccessfulVersionNegotiation(version); + QuicSession::OnSuccessfulVersionNegotiation(version); +} + void QuicClientSession::StartReading() { if (read_pending_) { return; @@ -296,25 +344,26 @@ void QuicClientSession::StartReading() { void QuicClientSession::CloseSessionOnError(int error) { UMA_HISTOGRAM_SPARSE_SLOWLY("Net.QuicSession.CloseSessionOnError", -error); - CloseSessionOnErrorInner(error); + CloseSessionOnErrorInner(error, QUIC_INTERNAL_ERROR); NotifyFactoryOfSessionClose(); } -void QuicClientSession::CloseSessionOnErrorInner(int error) { +void QuicClientSession::CloseSessionOnErrorInner(int net_error, + QuicErrorCode quic_error) { if (!callback_.is_null()) { - base::ResetAndReturn(&callback_).Run(error); + base::ResetAndReturn(&callback_).Run(net_error); } while (!streams()->empty()) { ReliableQuicStream* stream = streams()->begin()->second; QuicStreamId id = stream->id(); - static_cast<QuicReliableClientStream*>(stream)->OnError(error); + static_cast<QuicReliableClientStream*>(stream)->OnError(net_error); CloseStream(id); } net_log_.AddEvent( NetLog::TYPE_QUIC_SESSION_CLOSE_ON_ERROR, - NetLog::IntegerCallback("net_error", error)); + NetLog::IntegerCallback("net_error", net_error)); - connection()->CloseConnection(QUIC_INTERNAL_ERROR, false); + connection()->CloseConnection(quic_error, false); DCHECK(!connection()->connected()); } @@ -340,7 +389,8 @@ void QuicClientSession::OnReadComplete(int result) { if (result < 0) { DLOG(INFO) << "Closing session on read error: " << result; - CloseSessionOnErrorInner(result); + UMA_HISTOGRAM_SPARSE_SLOWLY("Net.QuicSession.ReadError", -result); + CloseSessionOnErrorInner(result, QUIC_PACKET_READ_ERROR); NotifyFactoryOfSessionCloseLater(); return; } diff --git a/chromium/net/quic/quic_client_session.h b/chromium/net/quic/quic_client_session.h index 339c40b2ba1..d167237d809 100644 --- a/chromium/net/quic/quic_client_session.h +++ b/chromium/net/quic/quic_client_session.h @@ -13,6 +13,7 @@ #include <string> #include "base/containers/hash_tables.h" +#include "base/memory/scoped_ptr.h" #include "net/base/completion_callback.h" #include "net/quic/quic_connection_logger.h" #include "net/quic/quic_crypto_client_stream.h" @@ -74,7 +75,7 @@ class NET_EXPORT_PRIVATE QuicClientSession : public QuicSession { // not |stream_factory|, which must outlive this session. // TODO(rch): decouple the factory from the session via a Delegate interface. QuicClientSession(QuicConnection* connection, - DatagramClientSocket* socket, + scoped_ptr<DatagramClientSocket> socket, QuicStreamFactory* stream_factory, QuicCryptoClientStreamFactory* crypto_client_stream_factory, const std::string& server_hostname, @@ -101,14 +102,23 @@ class NET_EXPORT_PRIVATE QuicClientSession : public QuicSession { virtual QuicReliableClientStream* CreateOutgoingReliableStream() OVERRIDE; virtual QuicCryptoClientStream* GetCryptoStream() OVERRIDE; virtual void CloseStream(QuicStreamId stream_id) OVERRIDE; + virtual void SendRstStream(QuicStreamId id, + QuicRstStreamErrorCode error) OVERRIDE; virtual void OnCryptoHandshakeEvent(CryptoHandshakeEvent event) OVERRIDE; + virtual void OnCryptoHandshakeMessageSent( + const CryptoHandshakeMessage& message) OVERRIDE; + virtual void OnCryptoHandshakeMessageReceived( + const CryptoHandshakeMessage& message) OVERRIDE; virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE; // QuicConnectionVisitorInterface methods: virtual void ConnectionClose(QuicErrorCode error, bool from_peer) OVERRIDE; + virtual void OnSuccessfulVersionNegotiation( + const QuicVersion& version) OVERRIDE; // Performs a crypto handshake with the server. - int CryptoConnect(const CompletionCallback& callback); + int CryptoConnect(bool require_confirmation, + const CompletionCallback& callback); // Causes the QuicConnectionHelper to start reading from the socket // and passing the data along to the QuicConnection. @@ -124,6 +134,11 @@ class NET_EXPORT_PRIVATE QuicClientSession : public QuicSession { base::WeakPtr<QuicClientSession> GetWeakPtr(); + // Returns the number of client hello messages that have been sent on the + // crypto stream. If the handshake has completed then this is one greater + // than the number of round-trips needed for the handshake. + int GetNumSentClientHellos() const; + protected: // QuicSession methods: virtual ReliableQuicStream* CreateIncomingReliableStream( @@ -138,7 +153,9 @@ class NET_EXPORT_PRIVATE QuicClientSession : public QuicSession { // A completion callback invoked when a read completes. void OnReadComplete(int result); - void CloseSessionOnErrorInner(int error); + void OnClosedStream(); + + void CloseSessionOnErrorInner(int net_error, QuicErrorCode quic_error); // Posts a task to notify the factory that this session has been closed. void NotifyFactoryOfSessionCloseLater(); @@ -147,7 +164,7 @@ class NET_EXPORT_PRIVATE QuicClientSession : public QuicSession { // delete |this|. void NotifyFactoryOfSessionClose(); - base::WeakPtrFactory<QuicClientSession> weak_factory_; + bool require_confirmation_; scoped_ptr<QuicCryptoClientStream> crypto_stream_; QuicStreamFactory* stream_factory_; scoped_ptr<DatagramClientSocket> socket_; @@ -158,6 +175,7 @@ class NET_EXPORT_PRIVATE QuicClientSession : public QuicSession { size_t num_total_streams_; BoundNetLog net_log_; QuicConnectionLogger logger_; + base::WeakPtrFactory<QuicClientSession> weak_factory_; DISALLOW_COPY_AND_ASSIGN(QuicClientSession); }; diff --git a/chromium/net/quic/quic_client_session_test.cc b/chromium/net/quic/quic_client_session_test.cc index 6113f4587c1..385be52a5b9 100644 --- a/chromium/net/quic/quic_client_session_test.cc +++ b/chromium/net/quic/quic_client_session_test.cc @@ -15,6 +15,7 @@ #include "net/quic/test_tools/crypto_test_utils.h" #include "net/quic/test_tools/quic_client_session_peer.h" #include "net/quic/test_tools/quic_test_utils.h" +#include "net/udp/datagram_client_socket.h" using testing::_; @@ -29,15 +30,16 @@ class QuicClientSessionTest : public ::testing::Test { QuicClientSessionTest() : guid_(1), connection_(new PacketSavingConnection(guid_, IPEndPoint(), false)), - session_(connection_, NULL, NULL, NULL, kServerHostname, - DefaultQuicConfig(), &crypto_config_, &net_log_) { + session_(connection_, scoped_ptr<DatagramClientSocket>(), NULL, + NULL, kServerHostname, DefaultQuicConfig(), &crypto_config_, + &net_log_) { session_.config()->SetDefaults(); crypto_config_.SetDefaults(); } void CompleteCryptoHandshake() { ASSERT_EQ(ERR_IO_PENDING, - session_.CryptoConnect(callback_.callback())); + session_.CryptoConnect(false, callback_.callback())); CryptoTestUtils::HandshakeWithFakeServer( connection_, session_.GetCryptoStream()); ASSERT_EQ(OK, callback_.WaitForResult()); diff --git a/chromium/net/quic/quic_connection.cc b/chromium/net/quic/quic_connection.cc index fefeba95f2d..7417bd435a1 100644 --- a/chromium/net/quic/quic_connection.cc +++ b/chromium/net/quic/quic_connection.cc @@ -19,6 +19,7 @@ using std::list; using std::make_pair; using std::min; using std::max; +using std::numeric_limits; using std::vector; using std::set; using std::string; @@ -32,7 +33,7 @@ const QuicPacketSequenceNumber kMaxPacketGap = 5000; // We want to make sure if we get a large nack packet, we don't queue up too // many packets at once. 10 is arbitrary. -const int kMaxRetransmissionsPerAck = 10; +const size_t kMaxRetransmissionsPerAck = 10; // TCP retransmits after 2 nacks. We allow for a third in case of out-of-order // delivery. @@ -40,11 +41,6 @@ const int kMaxRetransmissionsPerAck = 10; // at least 3 sequence numbers larger arrives. const size_t kNumberOfNacksBeforeRetransmission = 3; -// The maxiumum number of packets we'd like to queue. We may end up queueing -// more in the case of many control frames. -// 6 is arbitrary. -const int kMaxPacketsToSerializeAtOnce = 6; - // Limit the number of packets we send per retransmission-alarm so we // eventually cede. 10 is arbitrary. const size_t kMaxPacketsPerRetransmissionAlarm = 10; @@ -130,8 +126,31 @@ class TimeoutAlarm : public QuicAlarm::Delegate { QuicConnection* connection_; }; +// Indicates if any of the frames are intended to be sent with FORCE. +// Returns true when one of the frames is a CONNECTION_CLOSE_FRAME. +net::QuicConnection::Force HasForcedFrames( + const RetransmittableFrames* retransmittable_frames) { + if (!retransmittable_frames) { + return net::QuicConnection::NO_FORCE; + } + for (size_t i = 0; i < retransmittable_frames->frames().size(); ++i) { + if (retransmittable_frames->frames()[i].type == CONNECTION_CLOSE_FRAME) { + return net::QuicConnection::FORCE; + } + } + return net::QuicConnection::NO_FORCE; +} + } // namespace +// TODO(rch): Remove this. +// Because of a bug in the interaction between the TcpCubicSender and +// QuicConnection, acks currently count against the congestion window. +// This means that if acks are not acked, and data is only flowing in +// one direction, then the connection will deadlock. +// static +bool QuicConnection::g_acks_do_not_instigate_acks = false; + #define ENDPOINT (is_server_ ? "Server: " : " Client: ") QuicConnection::QuicConnection(QuicGuid guid, @@ -165,6 +184,7 @@ QuicConnection::QuicConnection(QuicGuid guid, time_of_last_received_packet_(clock_->ApproximateNow()), time_of_last_sent_packet_(clock_->ApproximateNow()), congestion_manager_(clock_, kTCP), + sent_packet_manager_(is_server, this), version_negotiation_state_(START_NEGOTIATION), max_packets_per_retransmission_alarm_(kMaxPacketsPerRetransmissionAlarm), is_server_(is_server), @@ -187,8 +207,8 @@ QuicConnection::QuicConnection(QuicGuid guid, } QuicConnection::~QuicConnection() { + STLDeleteElements(&ack_notifiers_); STLDeleteElements(&undecryptable_packets_); - STLDeleteValues(&unacked_packets_); STLDeleteValues(&group_map_); for (QueuedPacketList::iterator it = queued_packets_.begin(); it != queued_packets_.end(); ++it) { @@ -278,6 +298,7 @@ bool QuicConnection::OnProtocolVersionMismatch(QuicVersion received_version) { } version_negotiation_state_ = NEGOTIATED_VERSION; + visitor_->OnSuccessfulVersionNegotiation(received_version); // Store the new version. framer_.set_version(received_version); @@ -376,6 +397,7 @@ bool QuicConnection::OnPacketHeader(const QuicPacketHeader& header) { DCHECK_EQ(1u, header.public_header.versions.size()); DCHECK_EQ(header.public_header.versions[0], version()); version_negotiation_state_ = NEGOTIATED_VERSION; + visitor_->OnSuccessfulVersionNegotiation(version()); } } else { DCHECK(!header.public_header.version_flag); @@ -383,6 +405,7 @@ bool QuicConnection::OnPacketHeader(const QuicPacketHeader& header) { // it should stop sending version since the version negotiation is done. packet_creator_.StopSendingVersion(); version_negotiation_state_ = NEGOTIATED_VERSION; + visitor_->OnSuccessfulVersionNegotiation(version()); } } @@ -429,6 +452,17 @@ bool QuicConnection::OnAckFrame(const QuicAckFrame& incoming_ack) { SendConnectionClose(QUIC_INVALID_ACK_DATA); return false; } + + // Reset the RTO timeout for each packet when an ack is received. + if (retransmission_alarm_->IsSet()) { + retransmission_alarm_->Cancel(); + QuicTime::Delta retransmission_delay = + congestion_manager_.GetRetransmissionDelay( + sent_packet_manager_.GetNumUnackedPackets(), 0); + retransmission_alarm_->Set(clock_->ApproximateNow().Add( + retransmission_delay)); + } + last_ack_frames_.push_back(incoming_ack); return connected_; } @@ -448,11 +482,34 @@ void QuicConnection::ProcessAckFrame(const QuicAckFrame& incoming_ack) { sent_entropy_manager_.ClearEntropyBefore( received_packet_manager_.least_packet_awaited_by_peer() - 1); + retransmitted_nacked_packet_count_ = 0; SequenceNumberSet acked_packets; - HandleAckForSentPackets(incoming_ack, &acked_packets); - HandleAckForSentFecPackets(incoming_ack, &acked_packets); + sent_packet_manager_.HandleAckForSentPackets(incoming_ack, &acked_packets); + sent_packet_manager_.HandleAckForSentFecPackets(incoming_ack, &acked_packets); if (acked_packets.size() > 0) { - visitor_->OnAck(acked_packets); + // Inform all the registered AckNotifiers of the new ACKs. + // TODO(rjshade): Make this more efficient by maintaining a mapping of + // <sequence number, set<AckNotifierList>> so that OnAck + // is only called on AckNotifiers that care about the + // packets being ACKed. + AckNotifierList::iterator it = ack_notifiers_.begin(); + while (it != ack_notifiers_.end()) { + if ((*it)->OnAck(acked_packets)) { + // The QuicAckNotifier has seen all the ACKs it was interested in, and + // has triggered its callback. No more use for it. + delete *it; + it = ack_notifiers_.erase(it); + } else { + ++it; + } + } + } + // Clear the earliest retransmission timeouts that are no longer unacked to + // ensure the priority queue doesn't become too large. + while (!retransmission_timeouts_.empty() && + !sent_packet_manager_.IsUnacked( + retransmission_timeouts_.top().sequence_number)) { + retransmission_timeouts_.pop(); } congestion_manager_.OnIncomingAckFrame(incoming_ack, time_of_last_received_packet_); @@ -516,7 +573,7 @@ bool QuicConnection::ValidateAckFrame(const QuicAckFrame& incoming_ack) { incoming_ack.received_info.largest_observed) { DLOG(ERROR) << ENDPOINT << "Peer sent missing packet: " << *incoming_ack.received_info.missing_packets.rbegin() - << " greater than largest observed: " + << " which is greater than largest observed: " << incoming_ack.received_info.largest_observed; return false; } @@ -526,7 +583,7 @@ bool QuicConnection::ValidateAckFrame(const QuicAckFrame& incoming_ack) { received_packet_manager_.least_packet_awaited_by_peer()) { DLOG(ERROR) << ENDPOINT << "Peer sent missing packet: " << *incoming_ack.received_info.missing_packets.begin() - << "smaller than least_packet_awaited_by_peer_: " + << " which is smaller than least_packet_awaited_by_peer_: " << received_packet_manager_.least_packet_awaited_by_peer(); return false; } @@ -542,73 +599,6 @@ bool QuicConnection::ValidateAckFrame(const QuicAckFrame& incoming_ack) { return true; } -void QuicConnection::HandleAckForSentPackets(const QuicAckFrame& incoming_ack, - SequenceNumberSet* acked_packets) { - int retransmitted_packets = 0; - // Go through the packets we have not received an ack for and see if this - // incoming_ack shows they've been seen by the peer. - UnackedPacketMap::iterator it = unacked_packets_.begin(); - while (it != unacked_packets_.end()) { - QuicPacketSequenceNumber sequence_number = it->first; - if (sequence_number > - received_packet_manager_.peer_largest_observed_packet()) { - // These are very new sequence_numbers. - break; - } - RetransmittableFrames* unacked = it->second; - if (!IsAwaitingPacket(incoming_ack.received_info, sequence_number)) { - // Packet was acked, so remove it from our unacked packet list. - DVLOG(1) << ENDPOINT <<"Got an ack for packet " << sequence_number; - acked_packets->insert(sequence_number); - delete unacked; - unacked_packets_.erase(it++); - retransmission_map_.erase(sequence_number); - } else { - // This is a packet which we planned on retransmitting and has not been - // seen at the time of this ack being sent out. See if it's our new - // lowest unacked packet. - DVLOG(1) << ENDPOINT << "still missing packet " << sequence_number; - ++it; - // The peer got packets after this sequence number. This is an explicit - // nack. - RetransmissionMap::iterator retransmission_it = - retransmission_map_.find(sequence_number); - ++(retransmission_it->second.number_nacks); - if (retransmission_it->second.number_nacks >= - kNumberOfNacksBeforeRetransmission && - retransmitted_packets < kMaxRetransmissionsPerAck) { - ++retransmitted_packets; - DVLOG(1) << ENDPOINT << "Trying to retransmit packet " - << sequence_number - << " as it has been nacked 3 or more times."; - // RetransmitPacket will retransmit with a new sequence_number. - RetransmitPacket(sequence_number); - } - } - } -} - -void QuicConnection::HandleAckForSentFecPackets( - const QuicAckFrame& incoming_ack, SequenceNumberSet* acked_packets) { - UnackedPacketMap::iterator it = unacked_fec_packets_.begin(); - while (it != unacked_fec_packets_.end()) { - QuicPacketSequenceNumber sequence_number = it->first; - if (sequence_number > - received_packet_manager_.peer_largest_observed_packet()) { - break; - } - if (!IsAwaitingPacket(incoming_ack.received_info, sequence_number)) { - DVLOG(1) << ENDPOINT << "Got an ack for fec packet: " << sequence_number; - acked_packets->insert(sequence_number); - unacked_fec_packets_.erase(it++); - } else { - DVLOG(1) << ENDPOINT << "Still missing ack for fec packet: " - << sequence_number; - ++it; - } - } -} - void QuicConnection::OnFecData(const QuicFecData& fec) { DCHECK_EQ(IN_FEC_GROUP, last_header_.is_in_fec_group); DCHECK_NE(0u, last_header_.fec_group); @@ -678,9 +668,8 @@ void QuicConnection::OnPacketComplete() { // from unacket_packets_, increasing the least_unacked. const bool last_packet_should_instigate_ack = ShouldLastPacketInstigateAck(); - if ((last_stream_frames_.empty() || - visitor_->OnPacket(self_address_, peer_address_, - last_header_, last_stream_frames_))) { + if (last_stream_frames_.empty() || + visitor_->OnStreamFrames(last_stream_frames_)) { received_packet_manager_.RecordPacketReceived( last_header_, time_of_last_received_packet_); } @@ -738,20 +727,31 @@ bool QuicConnection::ShouldLastPacketInstigateAck() { // the high water mark. if (!last_ack_frames_.empty() && !last_ack_frames_.back().received_info.missing_packets.empty() && - !unacked_packets_.empty()) { - if (unacked_packets_.begin()->first > - *last_ack_frames_.back().received_info.missing_packets.begin()) { - return true; - } + sent_packet_manager_.HasUnackedPackets()) { + return sent_packet_manager_.GetLeastUnackedSentPacket() > + *last_ack_frames_.back().received_info.missing_packets.begin(); } - return false; } void QuicConnection::MaybeSendInResponseToPacket( bool last_packet_should_instigate_ack) { - // TODO(ianswett): Better merge these two blocks to queue up an ack if - // necessary, then either only send the ack or bundle it with other data. + packet_generator_.StartBatchOperations(); + + if (last_packet_should_instigate_ack || + !g_acks_do_not_instigate_acks) { + if (send_ack_in_response_to_packet_) { + SendAck(); + } else if (last_packet_should_instigate_ack) { + // Set the ack alarm for when any retransmittable frame is received. + if (!ack_alarm_->IsSet()) { + ack_alarm_->Set(clock_->ApproximateNow().Add( + congestion_manager_.DelayedAckTime())); + } + } + send_ack_in_response_to_packet_ = !send_ack_in_response_to_packet_; + } + if (!last_ack_frames_.empty()) { // Now the we have received an ack, we might be able to send packets which // are queued locally, or drain streams which are blocked. @@ -766,22 +766,7 @@ void QuicConnection::MaybeSendInResponseToPacket( send_alarm_->Set(time_of_last_received_packet_.Add(delay)); } } - - if (!last_packet_should_instigate_ack) { - return; - } - - if (send_ack_in_response_to_packet_) { - SendAck(); - } else if (!last_stream_frames_.empty()) { - // TODO(alyssar) this case should really be "if the packet contained any - // non-ack frame", rather than "if the packet contained a stream frame" - if (!ack_alarm_->IsSet()) { - ack_alarm_->Set(clock_->ApproximateNow().Add( - congestion_manager_.DefaultRetransmissionTime())); - } - } - send_ack_in_response_to_packet_ = !send_ack_in_response_to_packet_; + packet_generator_.FinishBatchOperations(); } void QuicConnection::SendVersionNegotiationPacket() { @@ -797,15 +782,101 @@ void QuicConnection::SendVersionNegotiationPacket() { delete encrypted; } -QuicConsumedData QuicConnection::SendStreamData(QuicStreamId id, - StringPiece data, - QuicStreamOffset offset, - bool fin) { - return packet_generator_.ConsumeData(id, data, offset, fin); +QuicConsumedData QuicConnection::SendvStreamDataInner( + QuicStreamId id, + const struct iovec* iov, + int iov_count, + QuicStreamOffset offset, + bool fin, + QuicAckNotifier* notifier) { + // TODO(ianswett): Further improve sending by passing the iovec down + // instead of batching into multiple stream frames in a single packet. + const bool already_in_batch_mode = packet_generator_.InBatchMode(); + packet_generator_.StartBatchOperations(); + + size_t bytes_written = 0; + bool fin_consumed = false; + + for (int i = 0; i < iov_count; ++i) { + bool send_fin = fin && (i == iov_count - 1); + if (!send_fin && iov[i].iov_len == 0) { + LOG(DFATAL) << "Attempt to send empty stream frame"; + } + + StringPiece data(static_cast<char*>(iov[i].iov_base), iov[i].iov_len); + int currentOffset = offset + bytes_written; + QuicConsumedData consumed_data = + packet_generator_.ConsumeData(id, + data, + currentOffset, + send_fin, + notifier); + + DCHECK_LE(consumed_data.bytes_consumed, numeric_limits<uint32>::max()); + bytes_written += consumed_data.bytes_consumed; + fin_consumed = consumed_data.fin_consumed; + // If no bytes were consumed, bail now, because the stream can not write + // more data. + if (consumed_data.bytes_consumed < iov[i].iov_len) { + break; + } + } + // Handle the 0 byte write properly. + if (iov_count == 0) { + DCHECK(fin); + QuicConsumedData consumed_data = packet_generator_.ConsumeData( + id, StringPiece(), offset, fin, NULL); + fin_consumed = consumed_data.fin_consumed; + } + + // Leave the generator in the original batch state. + if (!already_in_batch_mode) { + packet_generator_.FinishBatchOperations(); + } + DCHECK_EQ(already_in_batch_mode, packet_generator_.InBatchMode()); + + return QuicConsumedData(bytes_written, fin_consumed); +} + +QuicConsumedData QuicConnection::SendvStreamData(QuicStreamId id, + const struct iovec* iov, + int iov_count, + QuicStreamOffset offset, + bool fin) { + return SendvStreamDataInner(id, iov, iov_count, offset, fin, NULL); +} + +QuicConsumedData QuicConnection::SendvStreamDataAndNotifyWhenAcked( + QuicStreamId id, + const struct iovec* iov, + int iov_count, + QuicStreamOffset offset, + bool fin, + QuicAckNotifier::DelegateInterface* delegate) { + if (!fin && iov_count == 0) { + LOG(DFATAL) << "Attempt to send empty stream frame"; + } + // This notifier will be deleted in ProcessAckFrame once it has seen ACKs for + // all the consumed data (or below if no data was consumed). + QuicAckNotifier* notifier = new QuicAckNotifier(delegate); + QuicConsumedData consumed_data = + SendvStreamDataInner(id, iov, iov_count, offset, fin, notifier); + + if (consumed_data.bytes_consumed > 0) { + // If some data was consumed, then the delegate should be registered for + // notification when the data is ACKed. + ack_notifiers_.push_back(notifier); + } else { + // No data was consumed, delete the notifier. + delete notifier; + } + + return consumed_data; } void QuicConnection::SendRstStream(QuicStreamId id, QuicRstStreamErrorCode error) { + LOG(INFO) << "Sending RST_STREAM: " << id << " code: " << error; packet_generator_.AddControlFrame( QuicFrame(new QuicRstStreamFrame(id, error))); } @@ -878,21 +949,30 @@ bool QuicConnection::DoWrite() { DCHECK(!write_blocked_); WriteQueuedPackets(); + IsHandshake pending_handshake = visitor_->HasPendingHandshake() ? + IS_HANDSHAKE : NOT_HANDSHAKE; // Sending queued packets may have caused the socket to become write blocked, // or the congestion manager to prohibit sending. If we've sent everything // we had queued and we're still not blocked, let the visitor know it can // write more. if (CanWrite(NOT_RETRANSMISSION, HAS_RETRANSMITTABLE_DATA, - NOT_HANDSHAKE)) { - packet_generator_.StartBatchOperations(); + pending_handshake)) { + const bool already_in_batch_mode = packet_generator_.InBatchMode(); + if (!already_in_batch_mode) { + packet_generator_.StartBatchOperations(); + } bool all_bytes_written = visitor_->OnCanWrite(); - packet_generator_.FinishBatchOperations(); + if (!already_in_batch_mode) { + packet_generator_.FinishBatchOperations(); + } // After the visitor writes, it may have caused the socket to become write // blocked or the congestion manager to prohibit sending, so check again. + pending_handshake = visitor_->HasPendingHandshake() ? IS_HANDSHAKE + : NOT_HANDSHAKE; if (!write_blocked_ && !all_bytes_written && CanWrite(NOT_RETRANSMISSION, HAS_RETRANSMITTABLE_DATA, - NOT_HANDSHAKE)) { + pending_handshake)) { // We're not write blocked, but some stream didn't write out all of its // bytes. Register for 'immediate' resumption so we'll keep writing after // other quic connections have had a chance to use the socket. @@ -920,19 +1000,13 @@ bool QuicConnection::ProcessValidatedPacket() { bool QuicConnection::WriteQueuedPackets() { DCHECK(!write_blocked_); - size_t num_queued_packets = queued_packets_.size() + 1; QueuedPacketList::iterator packet_iterator = queued_packets_.begin(); while (!write_blocked_ && packet_iterator != queued_packets_.end()) { - // Ensure that from one iteration of this loop to the next we - // succeeded in sending a packet so we don't infinitely loop. - // TODO(rch): clean up and close the connection if we really hit this. - DCHECK_LT(queued_packets_.size(), num_queued_packets); - num_queued_packets = queued_packets_.size(); if (WritePacket(packet_iterator->encryption_level, packet_iterator->sequence_number, packet_iterator->packet, packet_iterator->retransmittable, - NO_FORCE)) { + packet_iterator->forced)) { packet_iterator = queued_packets_.erase(packet_iterator); } else { // Continue, because some queued packets may still be writable. @@ -946,100 +1020,96 @@ bool QuicConnection::WriteQueuedPackets() { bool QuicConnection::MaybeRetransmitPacketForRTO( QuicPacketSequenceNumber sequence_number) { - DCHECK_EQ(ContainsKey(unacked_packets_, sequence_number), - ContainsKey(retransmission_map_, sequence_number)); - - if (!ContainsKey(unacked_packets_, sequence_number)) { + if (!sent_packet_manager_.IsUnacked(sequence_number)) { DVLOG(2) << ENDPOINT << "alarm fired for " << sequence_number << " but it has been acked or already retransmitted with" - << " different sequence number."; + << " a different sequence number."; // So no extra delay is added for this packet. return true; } - RetransmissionMap::iterator retransmission_it = - retransmission_map_.find(sequence_number); // If the packet hasn't been acked and we're getting truncated acks, ignore // any RTO for packets larger than the peer's largest observed packet; it may // have been received by the peer and just wasn't acked due to the ack frame // running out of space. - if (received_truncated_ack_ && sequence_number > - received_packet_manager_.peer_largest_observed_packet() && + if (received_truncated_ack_ && + sequence_number > GetPeerLargestObservedPacket() && // We allow retransmission of already retransmitted packets so that we // retransmit packets that were retransmissions of the packet with // sequence number < the largest observed field of the truncated ack. - retransmission_it->second.number_retransmissions == 0) { + !sent_packet_manager_.IsRetransmission(sequence_number)) { return false; - } else { - ++stats_.rto_count; - RetransmitPacket(sequence_number); - return true; } + + ++stats_.rto_count; + RetransmitPacket(sequence_number); + return true; } void QuicConnection::RetransmitUnackedPackets( RetransmissionType retransmission_type) { - if (unacked_packets_.empty()) { + SequenceNumberSet unacked_packets = sent_packet_manager_.GetUnackedPackets(); + if (unacked_packets.empty()) { return; } - UnackedPacketMap::iterator next_it = unacked_packets_.begin(); - QuicPacketSequenceNumber end_sequence_number = - unacked_packets_.rbegin()->first; - do { - UnackedPacketMap::iterator current_it = next_it; - ++next_it; + for (SequenceNumberSet::const_iterator unacked_it = unacked_packets.begin(); + unacked_it != unacked_packets.end(); ++unacked_it) { + const RetransmittableFrames& frames = + sent_packet_manager_.GetRetransmittableFrames(*unacked_it); if (retransmission_type == ALL_PACKETS || - current_it->second->encryption_level() == ENCRYPTION_INITIAL) { + frames.encryption_level() == ENCRYPTION_INITIAL) { // TODO(satyamshekhar): Think about congestion control here. // Specifically, about the retransmission count of packets being sent // proactively to achieve 0 (minimal) RTT. - RetransmitPacket(current_it->first); + RetransmitPacket(*unacked_it); } - } while (next_it != unacked_packets_.end() && - next_it->first <= end_sequence_number); + } } void QuicConnection::RetransmitPacket( QuicPacketSequenceNumber sequence_number) { - UnackedPacketMap::iterator unacked_it = - unacked_packets_.find(sequence_number); - RetransmissionMap::iterator retransmission_it = - retransmission_map_.find(sequence_number); - // There should always be an entry corresponding to |sequence_number| in - // both |retransmission_map_| and |unacked_packets_|. Retransmissions due to - // RTO for sequence numbers that are already acked or retransmitted are - // ignored by MaybeRetransmitPacketForRTO. - DCHECK(unacked_it != unacked_packets_.end()); - DCHECK(retransmission_it != retransmission_map_.end()); - RetransmittableFrames* unacked = unacked_it->second; + DCHECK(sent_packet_manager_.IsUnacked(sequence_number)); + // TODO(pwestin): Need to fix potential issue with FEC and a 1 packet // congestion window see b/8331807 for details. congestion_manager_.AbandoningPacket(sequence_number); + const RetransmittableFrames& retransmittable_frames = + sent_packet_manager_.GetRetransmittableFrames(sequence_number); + // Re-packetize the frames with a new sequence number for retransmission. // Retransmitted data packets do not use FEC, even when it's enabled. + // Retransmitted packets use the same sequence number length as the original. + QuicSequenceNumberLength original_sequence_number_length = + sent_packet_manager_.GetSequenceNumberLength(sequence_number); SerializedPacket serialized_packet = - packet_creator_.SerializeAllFrames(unacked->frames()); - RetransmissionInfo retransmission_info(serialized_packet.sequence_number); - retransmission_info.number_retransmissions = - retransmission_it->second.number_retransmissions + 1; - // Remove info with old sequence number. - unacked_packets_.erase(unacked_it); - retransmission_map_.erase(retransmission_it); - DVLOG(1) << ENDPOINT << "Retransmitting unacked packet " << sequence_number - << " as " << serialized_packet.sequence_number; - DCHECK(unacked_packets_.empty() || - unacked_packets_.rbegin()->first < serialized_packet.sequence_number); - unacked_packets_.insert(make_pair(serialized_packet.sequence_number, - unacked)); - retransmission_map_.insert(make_pair(serialized_packet.sequence_number, - retransmission_info)); - SendOrQueuePacket(unacked->encryption_level(), + packet_creator_.ReserializeAllFrames(retransmittable_frames.frames(), + original_sequence_number_length); + + // A notifier may be waiting to hear about ACKs for the original sequence + // number. Inform them that the sequence number has changed. + for (AckNotifierList::iterator notifier_it = ack_notifiers_.begin(); + notifier_it != ack_notifiers_.end(); ++notifier_it) { + (*notifier_it)->UpdateSequenceNumber(sequence_number, + serialized_packet.sequence_number); + } + + DLOG(INFO) << ENDPOINT << "Retransmitting " << sequence_number << " as " + << serialized_packet.sequence_number; + if (debug_visitor_) { + debug_visitor_->OnPacketRetransmitted(sequence_number, + serialized_packet.sequence_number); + } + sent_packet_manager_.OnRetransmittedPacket(sequence_number, + serialized_packet.sequence_number); + + SendOrQueuePacket(retransmittable_frames.encryption_level(), serialized_packet.sequence_number, serialized_packet.packet, serialized_packet.entropy_hash, - HAS_RETRANSMITTABLE_DATA); + HAS_RETRANSMITTABLE_DATA, + HasForcedFrames(serialized_packet.retransmittable_frames)); } bool QuicConnection::CanWrite(Retransmission retransmission, @@ -1067,30 +1137,22 @@ bool QuicConnection::CanWrite(Retransmission retransmission, return true; } -bool QuicConnection::IsRetransmission( - QuicPacketSequenceNumber sequence_number) { - RetransmissionMap::iterator it = retransmission_map_.find(sequence_number); - return it != retransmission_map_.end() && - it->second.number_retransmissions > 0; -} - void QuicConnection::SetupRetransmission( QuicPacketSequenceNumber sequence_number, EncryptionLevel level) { - RetransmissionMap::iterator it = retransmission_map_.find(sequence_number); - if (it == retransmission_map_.end()) { + if (!sent_packet_manager_.IsUnacked(sequence_number)) { DVLOG(1) << ENDPOINT << "Will not retransmit packet " << sequence_number; return; } - - RetransmissionInfo retransmission_info = it->second; + size_t retransmission_count = + sent_packet_manager_.GetRetransmissionCount(sequence_number); // TODO(rch): consider using a much smaller retransmisison_delay // for the ENCRYPTION_NONE packets. size_t effective_retransmission_count = - level == ENCRYPTION_NONE ? 0 : retransmission_info.number_retransmissions; + level == ENCRYPTION_NONE ? 0 : retransmission_count; QuicTime::Delta retransmission_delay = congestion_manager_.GetRetransmissionDelay( - unacked_packets_.size(), + sent_packet_manager_.GetNumUnackedPackets(), effective_retransmission_count); retransmission_timeouts_.push(RetransmissionTime( @@ -1111,7 +1173,6 @@ void QuicConnection::SetupRetransmission( void QuicConnection::SetupAbandonFecTimer( QuicPacketSequenceNumber sequence_number) { - DCHECK(ContainsKey(unacked_fec_packets_, sequence_number)); QuicTime::Delta retransmission_delay = QuicTime::Delta::FromMilliseconds( congestion_manager_.DefaultRetransmissionTime().ToMilliseconds() * 3); @@ -1121,21 +1182,6 @@ void QuicConnection::SetupAbandonFecTimer( true)); } -void QuicConnection::DropPacket(QuicPacketSequenceNumber sequence_number) { - UnackedPacketMap::iterator unacked_it = - unacked_packets_.find(sequence_number); - // Packet was not meant to be retransmitted. - if (unacked_it == unacked_packets_.end()) { - DCHECK(!ContainsKey(retransmission_map_, sequence_number)); - return; - } - // Delete the unacked packet. - delete unacked_it->second; - unacked_packets_.erase(unacked_it); - retransmission_map_.erase(sequence_number); - return; -} - bool QuicConnection::WritePacket(EncryptionLevel level, QuicPacketSequenceNumber sequence_number, QuicPacket* packet, @@ -1154,15 +1200,20 @@ bool QuicConnection::WritePacket(EncryptionLevel level, level == ENCRYPTION_NONE) { // Drop packets that are NULL encrypted since the peer won't accept them // anymore. - DLOG(INFO) << ENDPOINT << "Dropped packet: " << sequence_number + DLOG(INFO) << ENDPOINT << "Dropping packet: " << sequence_number << " since the packet is NULL encrypted."; - DropPacket(sequence_number); + sent_packet_manager_.DiscardPacket(sequence_number); delete packet; return true; } - Retransmission retransmission = IsRetransmission(sequence_number) ? + Retransmission retransmission = + sent_packet_manager_.IsRetransmission(sequence_number) ? IS_RETRANSMISSION : NOT_RETRANSMISSION; + // TODO(wtc): use the same logic that is used in the packet generator. + // Namely, a packet is a handshake if it contains a stream frame for the + // crypto stream. It should be possible to look at the RetransmittableFrames + // in the SerializedPacket to determine this for a packet. IsHandshake handshake = level == ENCRYPTION_NONE ? IS_HANDSHAKE : NOT_HANDSHAKE; @@ -1174,6 +1225,12 @@ bool QuicConnection::WritePacket(EncryptionLevel level, scoped_ptr<QuicEncryptedPacket> encrypted( framer_.EncryptPacket(level, sequence_number, *packet)); + if (encrypted.get() == NULL) { + LOG(DFATAL) << ENDPOINT << "Failed to encrypt packet number " + << sequence_number; + CloseConnection(QUIC_ENCRYPTION_FAILURE, false); + return false; + } DLOG(INFO) << ENDPOINT << "Sending packet number " << sequence_number << " : " << (packet->is_fec_packet() ? "FEC " : (retransmittable == HAS_RETRANSMITTABLE_DATA @@ -1205,7 +1262,11 @@ bool QuicConnection::WritePacket(EncryptionLevel level, // If the socket buffers the the data, then the packet should not // be queued and sent again, which would result in an unnecessary // duplicate packet being sent. - return helper_->IsWriteBlockedDataBuffered(); + if (helper_->IsWriteBlockedDataBuffered()) { + delete packet; + return true; + } + return false; } // We can't send an error as the socket is presumably borked. CloseConnection(QUIC_PACKET_WRITE_ERROR, false); @@ -1221,8 +1282,15 @@ bool QuicConnection::WritePacket(EncryptionLevel level, SetupAbandonFecTimer(sequence_number); } + // TODO(ianswett): Change the sequence number length and other packet creator + // options by a more explicit API than setting a struct value directly. + packet_creator_.UpdateSequenceNumberLength( + received_packet_manager_.least_packet_awaited_by_peer(), + congestion_manager_.BandwidthEstimate().ToBytesPerPeriod( + congestion_manager_.SmoothedRtt())); + congestion_manager_.SentPacket(sequence_number, now, packet->length(), - retransmission); + retransmission, retransmittable); stats_.bytes_sent += encrypted->length(); ++stats_.packets_sent; @@ -1252,44 +1320,49 @@ int QuicConnection::WritePacketToWire(QuicPacketSequenceNumber sequence_number, bool QuicConnection::OnSerializedPacket( const SerializedPacket& serialized_packet) { - if (serialized_packet.retransmittable_frames != NULL) { - DCHECK(unacked_packets_.empty() || - unacked_packets_.rbegin()->first < - serialized_packet.sequence_number); - // Retransmitted frames will be sent with the same encryption level as the - // original. + if (serialized_packet.retransmittable_frames) { serialized_packet.retransmittable_frames->set_encryption_level( encryption_level_); - unacked_packets_.insert( - make_pair(serialized_packet.sequence_number, - serialized_packet.retransmittable_frames)); - // All unacked packets might be retransmitted. - retransmission_map_.insert( - make_pair(serialized_packet.sequence_number, - RetransmissionInfo(serialized_packet.sequence_number))); - } else if (serialized_packet.packet->is_fec_packet()) { - unacked_fec_packets_.insert(make_pair( - serialized_packet.sequence_number, - serialized_packet.retransmittable_frames)); } + sent_packet_manager_.OnSerializedPacket(serialized_packet); return SendOrQueuePacket(encryption_level_, serialized_packet.sequence_number, serialized_packet.packet, serialized_packet.entropy_hash, serialized_packet.retransmittable_frames != NULL ? HAS_RETRANSMITTABLE_DATA : - NO_RETRANSMITTABLE_DATA); + NO_RETRANSMITTABLE_DATA, + HasForcedFrames( + serialized_packet.retransmittable_frames)); +} + +QuicPacketSequenceNumber QuicConnection::GetPeerLargestObservedPacket() { + return received_packet_manager_.peer_largest_observed_packet(); +} + +QuicPacketSequenceNumber QuicConnection::GetNextPacketSequenceNumber() { + return packet_creator_.sequence_number() + 1; +} + +void QuicConnection::OnPacketNacked(QuicPacketSequenceNumber sequence_number, + size_t nack_count) { + if (nack_count >= kNumberOfNacksBeforeRetransmission && + retransmitted_nacked_packet_count_ < kMaxRetransmissionsPerAck) { + ++retransmitted_nacked_packet_count_; + RetransmitPacket(sequence_number); + } } bool QuicConnection::SendOrQueuePacket(EncryptionLevel level, QuicPacketSequenceNumber sequence_number, QuicPacket* packet, QuicPacketEntropyHash entropy_hash, - HasRetransmittableData retransmittable) { + HasRetransmittableData retransmittable, + Force forced) { sent_entropy_manager_.RecordPacketEntropyHash(sequence_number, entropy_hash); - if (!WritePacket(level, sequence_number, packet, retransmittable, NO_FORCE)) { + if (!WritePacket(level, sequence_number, packet, retransmittable, forced)) { queued_packets_.push_back(QueuedPacket(sequence_number, packet, level, - retransmittable)); + retransmittable, forced)); return false; } return true; @@ -1305,14 +1378,7 @@ bool QuicConnection::ShouldSimulateLostPacket() { } void QuicConnection::UpdateSentPacketInfo(SentPacketInfo* sent_info) { - if (!unacked_packets_.empty()) { - sent_info->least_unacked = unacked_packets_.begin()->first; - } else { - // If there are no unacked packets, set the least unacked packet to - // sequence_number() + 1 since that will be the sequence number of this - // ack packet whenever it is sent. - sent_info->least_unacked = packet_creator_.sequence_number() + 1; - } + sent_info->least_unacked = sent_packet_manager_.GetLeastUnackedSentPacket(); sent_info->entropy_hash = sent_entropy_manager_.EntropyHash( sent_info->least_unacked - 1); } @@ -1336,13 +1402,12 @@ void QuicConnection::SendAck() { void QuicConnection::MaybeAbandonFecPacket( QuicPacketSequenceNumber sequence_number) { - if (!ContainsKey(unacked_fec_packets_, sequence_number)) { + if (!sent_packet_manager_.IsFecUnacked(sequence_number)) { DVLOG(2) << ENDPOINT << "no need to abandon fec packet: " << sequence_number << "; it's already acked'"; return; } congestion_manager_.AbandoningPacket(sequence_number); - // TODO(satyashekhar): Should this decrease the congestion window? } QuicTime QuicConnection::OnRetransmissionTimeout() { @@ -1509,36 +1574,6 @@ void QuicConnection::SendConnectionClose(QuicErrorCode error) { SendConnectionCloseWithDetails(error, string()); } -void QuicConnection::SendConnectionClosePacket(QuicErrorCode error, - const string& details) { - DLOG(INFO) << ENDPOINT << "Force closing with error " - << QuicUtils::ErrorToString(error) << " (" << error << ") " - << details; - QuicConnectionCloseFrame frame; - frame.error_code = error; - frame.error_details = details; - UpdateSentPacketInfo(&frame.ack_frame.sent_info); - received_packet_manager_.UpdateReceivedPacketInfo( - &frame.ack_frame.received_info, clock_->ApproximateNow()); - - SerializedPacket serialized_packet = - packet_creator_.SerializeConnectionClose(&frame); - - // We need to update the sent entropy hash for all sent packets. - sent_entropy_manager_.RecordPacketEntropyHash( - serialized_packet.sequence_number, - serialized_packet.entropy_hash); - - if (!WritePacket(encryption_level_, - serialized_packet.sequence_number, - serialized_packet.packet, - serialized_packet.retransmittable_frames != NULL ? - HAS_RETRANSMITTABLE_DATA : NO_RETRANSMITTABLE_DATA, - FORCE)) { - delete serialized_packet.packet; - } -} - void QuicConnection::SendConnectionCloseWithDetails(QuicErrorCode error, const string& details) { if (!write_blocked_) { @@ -1547,6 +1582,21 @@ void QuicConnection::SendConnectionCloseWithDetails(QuicErrorCode error, CloseConnection(error, false); } +void QuicConnection::SendConnectionClosePacket(QuicErrorCode error, + const string& details) { + DLOG(INFO) << ENDPOINT << "Force closing with error " + << QuicUtils::ErrorToString(error) << " (" << error << ") " + << details; + QuicConnectionCloseFrame* frame = new QuicConnectionCloseFrame(); + frame->error_code = error; + frame->error_details = details; + UpdateSentPacketInfo(&frame->ack_frame.sent_info); + received_packet_manager_.UpdateReceivedPacketInfo( + &frame->ack_frame.received_info, clock_->ApproximateNow()); + packet_generator_.AddControlFrame(QuicFrame(frame)); + Flush(); +} + void QuicConnection::CloseConnection(QuicErrorCode error, bool from_peer) { DCHECK(connected_); connected_ = false; @@ -1584,6 +1634,14 @@ void QuicConnection::CloseFecGroupsBefore( } } +void QuicConnection::Flush() { + if (!packet_generator_.InBatchMode()) { + return; + } + packet_generator_.FinishBatchOperations(); + packet_generator_.StartBatchOperations(); +} + bool QuicConnection::HasQueuedData() const { return !queued_packets_.empty() || packet_generator_.HasQueuedFrames(); } diff --git a/chromium/net/quic/quic_connection.h b/chromium/net/quic/quic_connection.h index 41172f7d46d..15249097727 100644 --- a/chromium/net/quic/quic_connection.h +++ b/chromium/net/quic/quic_connection.h @@ -24,18 +24,21 @@ #include <vector> #include "base/containers/hash_tables.h" +#include "net/base/iovec.h" #include "net/base/ip_endpoint.h" #include "net/base/linked_hash_map.h" #include "net/quic/congestion_control/quic_congestion_manager.h" +#include "net/quic/quic_ack_notifier.h" #include "net/quic/quic_alarm.h" #include "net/quic/quic_blocked_writer_interface.h" +#include "net/quic/quic_connection_stats.h" #include "net/quic/quic_framer.h" #include "net/quic/quic_packet_creator.h" #include "net/quic/quic_packet_generator.h" #include "net/quic/quic_protocol.h" #include "net/quic/quic_received_packet_manager.h" #include "net/quic/quic_sent_entropy_manager.h" -#include "net/quic/quic_stats.h" +#include "net/quic/quic_sent_packet_manager.h" namespace net { @@ -48,6 +51,8 @@ namespace test { class QuicConnectionPeer; } // namespace test +// Class that receives callbacks from the connection when frames are received +// and when other interesting events happen. class NET_EXPORT_PRIVATE QuicConnectionVisitorInterface { public: virtual ~QuicConnectionVisitorInterface() {} @@ -56,10 +61,7 @@ class NET_EXPORT_PRIVATE QuicConnectionVisitorInterface { // should determine if all frames will be accepted, and return true if so. // If any frames can't be processed or buffered, none of the data should // be used, and the callee should return false. - virtual bool OnPacket(const IPEndPoint& self_address, - const IPEndPoint& peer_address, - const QuicPacketHeader& header, - const std::vector<QuicStreamFrame>& frame) = 0; + virtual bool OnStreamFrames(const std::vector<QuicStreamFrame>& frames) = 0; // Called when the stream is reset by the peer. virtual void OnRstStream(const QuicRstStreamFrame& frame) = 0; @@ -72,13 +74,16 @@ class NET_EXPORT_PRIVATE QuicConnectionVisitorInterface { virtual void ConnectionClose(QuicErrorCode error, bool from_peer) = 0; - // Called when packets are acked by the peer. - virtual void OnAck(const SequenceNumberSet& acked_packets) = 0; + // Called once a specific QUIC version is agreed by both endpoints. + virtual void OnSuccessfulVersionNegotiation(const QuicVersion& version) = 0; // Called when a blocked socket becomes writable. If all pending bytes for // this visitor are consumed by the connection successfully this should // return true, otherwise it should return false. virtual bool OnCanWrite() = 0; + + // Called to ask if any handshake messages are pending in this visitor. + virtual bool HasPendingHandshake() const = 0; }; // Interface which gets callbacks from the QuicConnection at interesting @@ -95,6 +100,12 @@ class NET_EXPORT_PRIVATE QuicConnectionDebugVisitorInterface const QuicEncryptedPacket& packet, int rv) = 0; + // Called when the contents of a packet have been retransmitted as + // a new packet. + virtual void OnPacketRetransmitted( + QuicPacketSequenceNumber old_sequence_number, + QuicPacketSequenceNumber new_sequence_number) = 0; + // Called when a packet has been received, but before it is // validated or parsed. virtual void OnPacketReceived(const IPEndPoint& self_address, @@ -177,7 +188,8 @@ class NET_EXPORT_PRIVATE QuicConnectionHelperInterface { class NET_EXPORT_PRIVATE QuicConnection : public QuicFramerVisitorInterface, public QuicBlockedWriterInterface, - public QuicPacketGenerator::DelegateInterface { + public QuicPacketGenerator::DelegateInterface, + public QuicSentPacketManager::HelperInterface { public: enum Force { NO_FORCE, @@ -198,17 +210,28 @@ class NET_EXPORT_PRIVATE QuicConnection QuicVersion version); virtual ~QuicConnection(); - static void DeleteEnclosedFrame(QuicFrame* frame); - - // Send the data payload to the peer. + // Send the data in |iov| to the peer in as few packets as possible. // Returns a pair with the number of bytes consumed from data, and a boolean // indicating if the fin bit was consumed. This does not indicate the data // has been sent on the wire: it may have been turned into a packet and queued // if the socket was unexpectedly blocked. - QuicConsumedData SendStreamData(QuicStreamId id, - base::StringPiece data, - QuicStreamOffset offset, - bool fin); + QuicConsumedData SendvStreamData(QuicStreamId id, + const struct iovec* iov, + int iov_count, + QuicStreamOffset offset, + bool fin); + + // Same as SendvStreamData, except the provided delegate will be informed + // once ACKs have been received for all the packets written. + // The |delegate| is not owned by the QuicConnection and must outlive it. + QuicConsumedData SendvStreamDataAndNotifyWhenAcked( + QuicStreamId id, + const struct iovec* iov, + int iov_count, + QuicStreamOffset offset, + bool fin, + QuicAckNotifier::DelegateInterface* delegate); + // Send a stream reset frame to the peer. virtual void SendRstStream(QuicStreamId id, QuicRstStreamErrorCode error); @@ -225,7 +248,7 @@ class NET_EXPORT_PRIVATE QuicConnection virtual void SendConnectionCloseWithDetails(QuicErrorCode error, const std::string& details); // Notifies the visitor of the close and marks the connection as disconnected. - void CloseConnection(QuicErrorCode error, bool from_peer); + virtual void CloseConnection(QuicErrorCode error, bool from_peer) OVERRIDE; virtual void SendGoAway(QuicErrorCode error, QuicStreamId last_good_stream_id, const std::string& reason); @@ -288,6 +311,12 @@ class NET_EXPORT_PRIVATE QuicConnection virtual QuicCongestionFeedbackFrame* CreateFeedbackFrame() OVERRIDE; virtual bool OnSerializedPacket(const SerializedPacket& packet) OVERRIDE; + // QuicSentPacketManager::HelperInterface + virtual QuicPacketSequenceNumber GetPeerLargestObservedPacket() OVERRIDE; + virtual QuicPacketSequenceNumber GetNextPacketSequenceNumber() OVERRIDE; + virtual void OnPacketNacked(QuicPacketSequenceNumber sequence_number, + size_t nack_count) OVERRIDE; + // Accessors void set_visitor(QuicConnectionVisitorInterface* visitor) { visitor_ = visitor; @@ -320,6 +349,10 @@ class NET_EXPORT_PRIVATE QuicConnection // Testing only. size_t NumQueuedPackets() const { return queued_packets_.size(); } + // Flush any queued frames immediately. Preserves the batch write mode and + // does nothing if there are no pending frames. + void Flush(); + // Returns true if the connection has queued packets or frames. bool HasQueuedData() const; @@ -381,19 +414,26 @@ class NET_EXPORT_PRIVATE QuicConnection const QuicDecrypter* decrypter() const; const QuicDecrypter* alternative_decrypter() const; + bool is_server() const { return is_server_; } + + static bool g_acks_do_not_instigate_acks; + protected: // Send a packet to the peer using encryption |level|. If |sequence_number| // is present in the |retransmission_map_|, then contents of this packet will // be retransmitted with a new sequence number if it's not acked by the peer. // Deletes |packet| via WritePacket call or transfers ownership to - // QueuedPacket, ultimately deleted via WritePacket. Also, it updates the + // QueuedPacket, ultimately deleted via WritePacket. Updates the // entropy map corresponding to |sequence_number| using |entropy_hash|. + // |retransmittable| is supplied to the congestion manager, and when |forced| + // is true, it bypasses the congestion manager. // TODO(wtc): none of the callers check the return value. virtual bool SendOrQueuePacket(EncryptionLevel level, QuicPacketSequenceNumber sequence_number, QuicPacket* packet, QuicPacketEntropyHash entropy_hash, - HasRetransmittableData retransmittable); + HasRetransmittableData retransmittable, + Force forced); // Writes the given packet to socket, encrypted with |level|, with the help // of helper. Returns true on successful write, false otherwise. However, @@ -429,33 +469,48 @@ class NET_EXPORT_PRIVATE QuicConnection private: friend class test::QuicConnectionPeer; + // Inner helper function to SendvStreamData and + // SendvStreamDataAndNotifyWhenAcked. + QuicConsumedData SendvStreamDataInner(QuicStreamId id, + const struct iovec* iov, + int iov_count, + QuicStreamOffset offset, + bool fin, + QuicAckNotifier *notifier); + // Packets which have not been written to the wire. // Owns the QuicPacket* packet. struct QueuedPacket { QueuedPacket(QuicPacketSequenceNumber sequence_number, QuicPacket* packet, EncryptionLevel level, - HasRetransmittableData retransmittable) + HasRetransmittableData retransmittable, + Force forced) : sequence_number(sequence_number), packet(packet), encryption_level(level), - retransmittable(retransmittable) { + retransmittable(retransmittable), + forced(forced) { } QuicPacketSequenceNumber sequence_number; QuicPacket* packet; const EncryptionLevel encryption_level; HasRetransmittableData retransmittable; + Force forced; }; struct RetransmissionInfo { - explicit RetransmissionInfo(QuicPacketSequenceNumber sequence_number) + RetransmissionInfo(QuicPacketSequenceNumber sequence_number, + QuicSequenceNumberLength sequence_number_length) : sequence_number(sequence_number), + sequence_number_length(sequence_number_length), number_nacks(0), number_retransmissions(0) { } QuicPacketSequenceNumber sequence_number; + QuicSequenceNumberLength sequence_number_length; size_t number_nacks; size_t number_retransmissions; }; @@ -484,15 +539,12 @@ class NET_EXPORT_PRIVATE QuicConnection }; typedef std::list<QueuedPacket> QueuedPacketList; - typedef linked_hash_map<QuicPacketSequenceNumber, - RetransmittableFrames*> UnackedPacketMap; typedef std::map<QuicFecGroupNumber, QuicFecGroup*> FecGroupMap; - typedef base::hash_map<QuicPacketSequenceNumber, - RetransmissionInfo> RetransmissionMap; typedef std::priority_queue<RetransmissionTime, std::vector<RetransmissionTime>, RetransmissionTimeComparator> RetransmissionTimeouts; + typedef std::list<QuicAckNotifier*> AckNotifierList; // Sends a version negotiation packet to the peer. void SendVersionNegotiationPacket(); @@ -510,6 +562,12 @@ class NET_EXPORT_PRIVATE QuicConnection // Returns false if the socket has become blocked. bool DoWrite(); + // Calculates the smallest sequence number length that can also represent four + // times the maximum of the congestion window and the difference between the + // least_packet_awaited_by_peer_ and |sequence_number|. + QuicSequenceNumberLength CalculateSequenceNumberLength( + QuicPacketSequenceNumber sequence_number); + // Drop packet corresponding to |sequence_number| by deleting entries from // |unacked_packets_| and |retransmission_map_|, if present. We need to drop // all packets with encryption level NONE after the default level has been set @@ -533,11 +591,6 @@ class NET_EXPORT_PRIVATE QuicConnection void ProcessAckFrame(const QuicAckFrame& incoming_ack); - void HandleAckForSentPackets(const QuicAckFrame& incoming_ack, - SequenceNumberSet* acked_packets); - void HandleAckForSentFecPackets(const QuicAckFrame& incoming_ack, - SequenceNumberSet* acked_packets); - // Update the |sent_info| for an outgoing ack. void UpdateSentPacketInfo(SentPacketInfo* sent_info); @@ -576,6 +629,9 @@ class NET_EXPORT_PRIVATE QuicConnection std::vector<QuicCongestionFeedbackFrame> last_congestion_frames_; std::vector<QuicRstStreamFrame> last_rst_frames_; std::vector<QuicGoAwayFrame> last_goaway_frames_; + // Then number of packets retransmitted because of nacks + // while processed the current ack frame. + size_t retransmitted_nacked_packet_count_; QuicCongestionFeedbackFrame outgoing_congestion_feedback_; @@ -583,16 +639,6 @@ class NET_EXPORT_PRIVATE QuicConnection // Largest sequence sent by the peer which had an ack frame (latest ack info). QuicPacketSequenceNumber largest_seen_packet_with_ack_; - // When new packets are created which may be retransmitted, they are added - // to this map, which contains owning pointers to the contained frames. - UnackedPacketMap unacked_packets_; - - // Pending fec packets that have not been acked yet. These packets need to be - // cleared out of the cgst_window after a timeout since FEC packets are never - // retransmitted. - // Ask: What should be the timeout for these packets? - UnackedPacketMap unacked_fec_packets_; - // Collection of packets which were received before encryption was // established, but which could not be decrypted. We buffer these on // the assumption that they could not be processed because they were @@ -607,9 +653,6 @@ class NET_EXPORT_PRIVATE QuicConnection // contains all packets that have been retransmitted x times. RetransmissionTimeouts retransmission_timeouts_; - // Map from sequence number to the retransmission info. - RetransmissionMap retransmission_map_; - // True while OnRetransmissionTimeout is running to prevent // SetRetransmissionAlarm from being called erroneously. bool handling_retransmission_timeout_; @@ -663,6 +706,10 @@ class NET_EXPORT_PRIVATE QuicConnection // as well as collecting and generating congestion feedback. QuicCongestionManager congestion_manager_; + // Sent packet manager which tracks the status of packets sent by this + // connection. + QuicSentPacketManager sent_packet_manager_; + // The state of connection in version negotiation finite state machine. QuicVersionNegotiationState version_negotiation_state_; @@ -684,6 +731,12 @@ class NET_EXPORT_PRIVATE QuicConnection // This is checked later on validating a data or version negotiation packet. bool address_migrating_; + // On every ACK frame received by this connection, all the ack_notifiers_ will + // be told which sequeunce numbers were ACKed. + // Once a given QuicAckNotifier has seen all the sequence numbers it is + // interested in, it will be deleted, and removed from this list. + AckNotifierList ack_notifiers_; + DISALLOW_COPY_AND_ASSIGN(QuicConnection); }; diff --git a/chromium/net/quic/quic_connection_helper.cc b/chromium/net/quic/quic_connection_helper.cc index c3e796fa321..703151b704f 100644 --- a/chromium/net/quic/quic_connection_helper.cc +++ b/chromium/net/quic/quic_connection_helper.cc @@ -6,6 +6,7 @@ #include "base/location.h" #include "base/logging.h" +#include "base/metrics/sparse_histogram.h" #include "base/task_runner.h" #include "base/time/time.h" #include "net/base/io_buffer.h" @@ -125,6 +126,9 @@ int QuicConnectionHelper::WritePacketToWire( if (rv >= 0) { *error = 0; } else { + if (rv != ERR_IO_PENDING) { + UMA_HISTOGRAM_SPARSE_SLOWLY("Net.QuicSession.WriteError", -rv); + } *error = rv; rv = -1; } diff --git a/chromium/net/quic/quic_connection_helper_test.cc b/chromium/net/quic/quic_connection_helper_test.cc index 4822ea67541..aaa692b548c 100644 --- a/chromium/net/quic/quic_connection_helper_test.cc +++ b/chromium/net/quic/quic_connection_helper_test.cc @@ -9,6 +9,7 @@ #include "net/base/net_errors.h" #include "net/quic/crypto/quic_decrypter.h" #include "net/quic/crypto/quic_encrypter.h" +#include "net/quic/quic_connection.h" #include "net/quic/test_tools/mock_clock.h" #include "net/quic/test_tools/quic_connection_peer.h" #include "net/quic/test_tools/quic_test_utils.h" @@ -18,6 +19,8 @@ #include "testing/gtest/include/gtest/gtest.h" using testing::_; +using testing::AnyNumber; +using testing::Return; namespace net { namespace test { @@ -120,7 +123,14 @@ class QuicConnectionHelperTest : public ::testing::Test { &random_generator_, socket_.get()); send_algorithm_ = new testing::StrictMock<MockSendAlgorithm>(); EXPECT_CALL(*send_algorithm_, TimeUntilSend(_, _, _, _)). - WillRepeatedly(testing::Return(QuicTime::Delta::Zero())); + WillRepeatedly(Return(QuicTime::Delta::Zero())); + EXPECT_CALL(*send_algorithm_, BandwidthEstimate()).WillRepeatedly( + Return(QuicBandwidth::FromKBitsPerSecond(100))); + EXPECT_CALL(*send_algorithm_, SmoothedRtt()).WillRepeatedly( + Return(QuicTime::Delta::FromMilliseconds(100))); + ON_CALL(*send_algorithm_, SentPacket(_, _, _, _, _)) + .WillByDefault(Return(true)); + EXPECT_CALL(visitor_, HasPendingHandshake()).Times(AnyNumber()); connection_.reset(new TestConnection(guid_, IPEndPoint(), helper_)); connection_->set_visitor(&visitor_); connection_->SetSendAlgorithm(send_algorithm_); @@ -199,6 +209,7 @@ class QuicConnectionHelperTest : public ::testing::Test { header_.public_header.guid = guid_; header_.public_header.reset_flag = false; header_.public_header.version_flag = true; + header_.public_header.sequence_number_length = PACKET_1BYTE_SEQUENCE_NUMBER; header_.packet_sequence_number = sequence_number; header_.entropy_flag = false; header_.fec_flag = false; @@ -304,17 +315,20 @@ TEST_F(QuicConnectionHelperTest, TestRetransmission) { Initialize(); EXPECT_CALL(*send_algorithm_, RetransmissionDelay()).WillRepeatedly( - testing::Return(QuicTime::Delta::Zero())); + Return(QuicTime::Delta::Zero())); QuicTime::Delta kDefaultRetransmissionTime = QuicTime::Delta::FromMilliseconds(500); QuicTime start = clock_.ApproximateNow(); - EXPECT_CALL(*send_algorithm_, SentPacket(_, 1, _, NOT_RETRANSMISSION)); + EXPECT_CALL(*send_algorithm_, SentPacket(_, 1, _, NOT_RETRANSMISSION, _)); EXPECT_CALL(*send_algorithm_, AbandoningPacket(1, _)); + // Send a packet. - connection_->SendStreamData(1, kData, 0, false); - EXPECT_CALL(*send_algorithm_, SentPacket(_, 2, _, IS_RETRANSMISSION)); + struct iovec iov = {const_cast<char*>(kData), + static_cast<size_t>(strlen(kData))}; + connection_->SendvStreamData(1, &iov, 1, 0, false); + EXPECT_CALL(*send_algorithm_, SentPacket(_, 2, _, IS_RETRANSMISSION, _)); // Since no ack was received, the retransmission alarm will fire and // retransmit it. runner_->RunNextTask(); @@ -331,17 +345,20 @@ TEST_F(QuicConnectionHelperTest, TestMultipleRetransmission) { Initialize(); EXPECT_CALL(*send_algorithm_, RetransmissionDelay()).WillRepeatedly( - testing::Return(QuicTime::Delta::Zero())); + Return(QuicTime::Delta::Zero())); QuicTime::Delta kDefaultRetransmissionTime = QuicTime::Delta::FromMilliseconds(500); QuicTime start = clock_.ApproximateNow(); - EXPECT_CALL(*send_algorithm_, SentPacket(_, 1, _, NOT_RETRANSMISSION)); + EXPECT_CALL(*send_algorithm_, SentPacket(_, 1, _, NOT_RETRANSMISSION, _)); EXPECT_CALL(*send_algorithm_, AbandoningPacket(1, _)); + // Send a packet. - connection_->SendStreamData(1, kData, 0, false); - EXPECT_CALL(*send_algorithm_, SentPacket(_, 2, _, IS_RETRANSMISSION)); + struct iovec iov = {const_cast<char*>(kData), + static_cast<size_t>(strlen(kData))}; + connection_->SendvStreamData(1, &iov, 1, 0, false); + EXPECT_CALL(*send_algorithm_, SentPacket(_, 2, _, IS_RETRANSMISSION, _)); // Since no ack was received, the retransmission alarm will fire and // retransmit it. runner_->RunNextTask(); @@ -351,7 +368,7 @@ TEST_F(QuicConnectionHelperTest, TestMultipleRetransmission) { // Since no ack was received, the retransmission alarm will fire and // retransmit it. - EXPECT_CALL(*send_algorithm_, SentPacket(_, 3, _, IS_RETRANSMISSION)); + EXPECT_CALL(*send_algorithm_, SentPacket(_, 3, _, IS_RETRANSMISSION, _)); EXPECT_CALL(*send_algorithm_, AbandoningPacket(2, _)); runner_->RunNextTask(); @@ -370,7 +387,10 @@ TEST_F(QuicConnectionHelperTest, InitialTimeout) { EXPECT_EQ(base::TimeDelta::FromSeconds(kDefaultInitialTimeoutSecs), runner_->GetPostedTasks().front().delay); - EXPECT_CALL(*send_algorithm_, SentPacket(_, 1, _, NOT_RETRANSMISSION)); + EXPECT_CALL(*send_algorithm_, SentPacket(_, 1, _, NOT_RETRANSMISSION, + HAS_RETRANSMITTABLE_DATA)); + EXPECT_CALL(*send_algorithm_, RetransmissionDelay()).WillOnce( + Return(QuicTime::Delta::FromMicroseconds(1))); // After we run the next task, we should close the connection. EXPECT_CALL(visitor_, ConnectionClose(QUIC_CONNECTION_TIMED_OUT, false)); @@ -397,7 +417,7 @@ TEST_F(QuicConnectionHelperTest, WritePacketToWireAsync) { AddWrite(ASYNC, ConstructClosePacket(1, 0)); Initialize(); - EXPECT_CALL(visitor_, OnCanWrite()).WillOnce(testing::Return(true)); + EXPECT_CALL(visitor_, OnCanWrite()).WillOnce(Return(true)); int error = 0; EXPECT_EQ(-1, helper_->WritePacketToWire(*GetWrite(0), &error)); EXPECT_EQ(ERR_IO_PENDING, error); @@ -417,7 +437,8 @@ TEST_F(QuicConnectionHelperTest, TimeoutAfterSend) { // kDefaultInitialTimeoutSecs. clock_.AdvanceTime(QuicTime::Delta::FromMicroseconds(5000)); EXPECT_EQ(5000u, clock_.ApproximateNow().Subtract(start).ToMicroseconds()); - EXPECT_CALL(*send_algorithm_, SentPacket(_, 1, _, NOT_RETRANSMISSION)); + EXPECT_CALL(*send_algorithm_, + SentPacket(_, 1, _, NOT_RETRANSMISSION, NO_RETRANSMITTABLE_DATA)); // Send an ack so we don't set the retransmission alarm. connection_->SendAck(); @@ -433,7 +454,10 @@ TEST_F(QuicConnectionHelperTest, TimeoutAfterSend) { // This time, we should time out. EXPECT_CALL(visitor_, ConnectionClose(QUIC_CONNECTION_TIMED_OUT, !kFromPeer)); - EXPECT_CALL(*send_algorithm_, SentPacket(_, 2, _, NOT_RETRANSMISSION)); + EXPECT_CALL(*send_algorithm_, SentPacket(_, 2, _, NOT_RETRANSMISSION, + HAS_RETRANSMITTABLE_DATA)); + EXPECT_CALL(*send_algorithm_, RetransmissionDelay()).WillOnce( + Return(QuicTime::Delta::FromMicroseconds(1))); runner_->RunNextTask(); EXPECT_EQ(kDefaultInitialTimeoutSecs * 1000000 + 5000, clock_.ApproximateNow().Subtract( @@ -448,23 +472,26 @@ TEST_F(QuicConnectionHelperTest, SendSchedulerDelayThenSend) { // Test that if we send a packet with a delay, it ends up queued. EXPECT_CALL(*send_algorithm_, RetransmissionDelay()).WillRepeatedly( - testing::Return(QuicTime::Delta::Zero())); + Return(QuicTime::Delta::Zero())); EXPECT_CALL( *send_algorithm_, TimeUntilSend(_, NOT_RETRANSMISSION, _, _)).WillOnce( - testing::Return(QuicTime::Delta::FromMicroseconds(1))); + Return(QuicTime::Delta::FromMicroseconds(1))); QuicPacket* packet = ConstructRawDataPacket(1); - connection_->SendOrQueuePacket( - ENCRYPTION_NONE, 1, packet, 0, HAS_RETRANSMITTABLE_DATA); - EXPECT_CALL(*send_algorithm_, SentPacket(_, 1, _, NOT_RETRANSMISSION)); + connection_->SendOrQueuePacket(ENCRYPTION_NONE, 1, packet, 0, + HAS_RETRANSMITTABLE_DATA, + QuicConnection::NO_FORCE); + EXPECT_CALL(*send_algorithm_, SentPacket(_, 1, _, NOT_RETRANSMISSION, + _)); EXPECT_EQ(1u, connection_->NumQueuedPackets()); // Advance the clock to fire the alarm, and configure the scheduler // to permit the packet to be sent. EXPECT_CALL(*send_algorithm_, TimeUntilSend(_, NOT_RETRANSMISSION, _, _)).WillRepeatedly( - testing::Return(QuicTime::Delta::Zero())); - EXPECT_CALL(visitor_, OnCanWrite()).WillOnce(testing::Return(true)); + Return(QuicTime::Delta::Zero())); + EXPECT_CALL(visitor_, OnCanWrite()).WillOnce(Return(true)); + EXPECT_CALL(visitor_, HasPendingHandshake()).Times(AnyNumber()); runner_->RunNextTask(); EXPECT_EQ(0u, connection_->NumQueuedPackets()); EXPECT_TRUE(AtEof()); diff --git a/chromium/net/quic/quic_connection_logger.cc b/chromium/net/quic/quic_connection_logger.cc index 3405cd0c764..5195f4391a3 100644 --- a/chromium/net/quic/quic_connection_logger.cc +++ b/chromium/net/quic/quic_connection_logger.cc @@ -10,6 +10,9 @@ #include "base/strings/string_number_conversions.h" #include "base/values.h" #include "net/base/net_log.h" +#include "net/quic/crypto/crypto_handshake.h" + +using std::string; namespace net { @@ -43,6 +46,18 @@ base::Value* NetLogQuicPacketSentCallback( return dict; } +base::Value* NetLogQuicPacketRetransmittedCallback( + QuicPacketSequenceNumber old_sequence_number, + QuicPacketSequenceNumber new_sequence_number, + NetLog::LogLevel /* log_level */) { + base::DictionaryValue* dict = new base::DictionaryValue(); + dict->SetString("old_packet_sequence_number", + base::Uint64ToString(old_sequence_number)); + dict->SetString("new_packet_sequence_number", + base::Uint64ToString(new_sequence_number)); + return dict; +} + base::Value* NetLogQuicPacketHeaderCallback(const QuicPacketHeader* header, NetLog::LogLevel /* log_level */) { base::DictionaryValue* dict = new base::DictionaryValue(); @@ -86,7 +101,7 @@ base::Value* NetLogQuicAckFrameCallback(const QuicAckFrame* frame, frame->received_info.missing_packets; for (SequenceNumberSet::const_iterator it = missing_packets.begin(); it != missing_packets.end(); ++it) { - missing->Append(new base::StringValue(base::Uint64ToString(*it))); + missing->AppendString(base::Uint64ToString(*it)); } return dict; } @@ -107,7 +122,7 @@ base::Value* NetLogQuicCongestionFeedbackFrameCallback( it != frame->inter_arrival.received_packet_times.end(); ++it) { std::string value = base::Uint64ToString(it->first) + "@" + base::Uint64ToString(it->second.ToDebuggingValue()); - received->Append(new base::StringValue(value)); + received->AppendString(value); } break; } @@ -146,6 +161,37 @@ base::Value* NetLogQuicConnectionCloseFrameCallback( return dict; } +base::Value* NetLogQuicVersionNegotiationPacketCallback( + const QuicVersionNegotiationPacket* packet, + NetLog::LogLevel /* log_level */) { + base::DictionaryValue* dict = new base::DictionaryValue(); + base::ListValue* versions = new base::ListValue(); + dict->Set("versions", versions); + for (QuicVersionVector::const_iterator it = packet->versions.begin(); + it != packet->versions.end(); ++it) { + versions->AppendString(QuicVersionToString(*it)); + } + return dict; +} + +base::Value* NetLogQuicCryptoHandshakeMessageCallback( + const CryptoHandshakeMessage* message, + NetLog::LogLevel /* log_level */) { + base::DictionaryValue* dict = new base::DictionaryValue(); + dict->SetString("quic_crypto_handshake_message", message->DebugString()); + return dict; +} + +base::Value* NetLogQuicConnectionClosedCallback( + QuicErrorCode error, + bool from_peer, + NetLog::LogLevel /* log_level */) { + base::DictionaryValue* dict = new base::DictionaryValue(); + dict->SetInteger("quic_error", error); + dict->SetBoolean("from_peer", from_peer); + return dict; +} + void UpdatePacketGapSentHistogram(size_t num_consecutive_missing_packets) { UMA_HISTOGRAM_COUNTS("Net.QuicSession.PacketGapSent", num_consecutive_missing_packets); @@ -216,6 +262,15 @@ void QuicConnectionLogger::OnPacketSent( packet.length(), rv)); } +void QuicConnectionLogger:: OnPacketRetransmitted( + QuicPacketSequenceNumber old_sequence_number, + QuicPacketSequenceNumber new_sequence_number) { + net_log_.AddEvent( + NetLog::TYPE_QUIC_SESSION_PACKET_RETRANSMITTED, + base::Bind(&NetLogQuicPacketRetransmittedCallback, + old_sequence_number, new_sequence_number)); +} + void QuicConnectionLogger::OnPacketReceived(const IPEndPoint& self_address, const IPEndPoint& peer_address, const QuicEncryptedPacket& packet) { @@ -248,6 +303,9 @@ void QuicConnectionLogger::OnPacketHeader(const QuicPacketHeader& header) { } if (header.packet_sequence_number < last_received_packet_sequence_number_) { ++out_of_order_recieved_packet_count_; + UMA_HISTOGRAM_COUNTS("Net.QuicSession.OutOfOrderGapReceived", + last_received_packet_sequence_number_ - + header.packet_sequence_number); } last_received_packet_sequence_number_ = header.packet_sequence_number; } @@ -321,15 +379,50 @@ void QuicConnectionLogger::OnConnectionCloseFrame( void QuicConnectionLogger::OnPublicResetPacket( const QuicPublicResetPacket& packet) { + net_log_.AddEvent(NetLog::TYPE_QUIC_SESSION_PUBLIC_RESET_PACKET_RECEIVED); } void QuicConnectionLogger::OnVersionNegotiationPacket( const QuicVersionNegotiationPacket& packet) { + net_log_.AddEvent( + NetLog::TYPE_QUIC_SESSION_VERSION_NEGOTIATION_PACKET_RECEIVED, + base::Bind(&NetLogQuicVersionNegotiationPacketCallback, &packet)); } void QuicConnectionLogger::OnRevivedPacket( const QuicPacketHeader& revived_header, base::StringPiece payload) { + net_log_.AddEvent( + NetLog::TYPE_QUIC_SESSION_PACKET_HEADER_REVIVED, + base::Bind(&NetLogQuicPacketHeaderCallback, &revived_header)); +} + +void QuicConnectionLogger::OnCryptoHandshakeMessageReceived( + const CryptoHandshakeMessage& message) { + net_log_.AddEvent( + NetLog::TYPE_QUIC_SESSION_CRYPTO_HANDSHAKE_MESSAGE_RECEIVED, + base::Bind(&NetLogQuicCryptoHandshakeMessageCallback, &message)); +} + +void QuicConnectionLogger::OnCryptoHandshakeMessageSent( + const CryptoHandshakeMessage& message) { + net_log_.AddEvent( + NetLog::TYPE_QUIC_SESSION_CRYPTO_HANDSHAKE_MESSAGE_SENT, + base::Bind(&NetLogQuicCryptoHandshakeMessageCallback, &message)); +} + +void QuicConnectionLogger::OnConnectionClose(QuicErrorCode error, + bool from_peer) { + net_log_.AddEvent( + NetLog::TYPE_QUIC_SESSION_CLOSED, + base::Bind(&NetLogQuicConnectionClosedCallback, error, from_peer)); +} + +void QuicConnectionLogger::OnSuccessfulVersionNegotiation( + const QuicVersion& version) { + string quic_version = QuicVersionToString(version); + net_log_.AddEvent(NetLog::TYPE_QUIC_SESSION_VERSION_NEGOTIATED, + NetLog::StringCallback("version", &quic_version)); } } // namespace net diff --git a/chromium/net/quic/quic_connection_logger.h b/chromium/net/quic/quic_connection_logger.h index 1d2bd2df96d..d498b128bd5 100644 --- a/chromium/net/quic/quic_connection_logger.h +++ b/chromium/net/quic/quic_connection_logger.h @@ -10,6 +10,7 @@ namespace net { class BoundNetLog; +class CryptoHandshakeMessage; // This class is a debug visitor of a QuicConnection which logs // events to |net_log|. @@ -28,7 +29,9 @@ class NET_EXPORT_PRIVATE QuicConnectionLogger EncryptionLevel level, const QuicEncryptedPacket& packet, int rv) OVERRIDE; - + virtual void OnPacketRetransmitted( + QuicPacketSequenceNumber old_sequence_number, + QuicPacketSequenceNumber new_sequence_number) OVERRIDE; virtual void OnPacketReceived(const IPEndPoint& self_address, const IPEndPoint& peer_address, const QuicEncryptedPacket& packet) OVERRIDE; @@ -48,6 +51,13 @@ class NET_EXPORT_PRIVATE QuicConnectionLogger virtual void OnRevivedPacket(const QuicPacketHeader& revived_header, base::StringPiece payload) OVERRIDE; + void OnCryptoHandshakeMessageReceived( + const CryptoHandshakeMessage& message); + void OnCryptoHandshakeMessageSent( + const CryptoHandshakeMessage& message); + void OnConnectionClose(QuicErrorCode error, bool from_peer); + void OnSuccessfulVersionNegotiation(const QuicVersion& version); + private: BoundNetLog net_log_; // The last packet sequence number received. diff --git a/chromium/net/quic/quic_stats.cc b/chromium/net/quic/quic_connection_stats.cc index 7404d927843..f66a5cba799 100644 --- a/chromium/net/quic/quic_stats.cc +++ b/chromium/net/quic/quic_connection_stats.cc @@ -1,8 +1,8 @@ -// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Copyright 2013 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#include "net/quic/quic_stats.h" +#include "net/quic/quic_connection_stats.h" namespace net { diff --git a/chromium/net/quic/quic_stats.h b/chromium/net/quic/quic_connection_stats.h index 252791e80de..f9336621f99 100644 --- a/chromium/net/quic/quic_stats.h +++ b/chromium/net/quic/quic_connection_stats.h @@ -1,9 +1,9 @@ -// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Copyright 2013 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#ifndef NET_QUIC_QUIC_STATS_H_ -#define NET_QUIC_QUIC_STATS_H_ +#ifndef NET_QUIC_QUIC_CONNECTION_STATS_H_ +#define NET_QUIC_QUIC_CONNECTION_STATS_H_ #include "base/basictypes.h" #include "net/base/net_export.h" @@ -47,4 +47,4 @@ struct NET_EXPORT_PRIVATE QuicConnectionStats { } // namespace net -#endif // NET_QUIC_QUIC_STATS_H_ +#endif // NET_QUIC_QUIC_CONNECTION_STATS_H_ diff --git a/chromium/net/quic/quic_connection_test.cc b/chromium/net/quic/quic_connection_test.cc index bd698c3eb1d..afe8a7663b8 100644 --- a/chromium/net/quic/quic_connection_test.cc +++ b/chromium/net/quic/quic_connection_test.cc @@ -29,14 +29,14 @@ using std::map; using std::vector; using testing::_; using testing::AnyNumber; -using testing::Between; using testing::ContainerEq; using testing::DoAll; using testing::InSequence; using testing::InvokeWithoutArgs; +using testing::Ref; using testing::Return; -using testing::StrictMock; using testing::SaveArg; +using testing::StrictMock; namespace net { namespace test { @@ -50,6 +50,13 @@ const bool kEntropyFlag = true; const QuicPacketEntropyHash kTestEntropyHash = 76; +const int kDefaultRetransmissionTimeMs = 500; + +// Used by TestConnection::SendStreamData3. +const QuicStreamId kStreamId3 = 3; +// Used by TestConnection::SendStreamData5. +const QuicStreamId kStreamId5 = 5; + class TestReceiveAlgorithm : public ReceiveAlgorithmInterface { public: explicit TestReceiveAlgorithm(QuicCongestionFeedbackFrame* feedback) @@ -382,12 +389,59 @@ class TestConnection : public QuicConnection { QuicConnectionPeer::SetSendAlgorithm(this, send_algorithm); } - QuicConsumedData SendStreamData1() { - return SendStreamData(1u, "food", 0, !kFin); + bool SendOrQueuePacket(EncryptionLevel level, + QuicPacketSequenceNumber sequence_number, + QuicPacket* packet, + QuicPacketEntropyHash entropy_hash, + HasRetransmittableData retransmittable) { + return SendOrQueuePacket(level, + sequence_number, + packet, + entropy_hash, + retransmittable, + NO_FORCE); + } + + QuicConsumedData SendStreamData(QuicStreamId id, + StringPiece data, + QuicStreamOffset offset, + bool fin) { + struct iovec iov = {const_cast<char*>(data.data()), + static_cast<size_t>(data.size())}; + return SendvStreamData(id, &iov, 1, offset, fin); + } + + QuicConsumedData SendStreamDataAndNotifyWhenAcked( + QuicStreamId id, + StringPiece data, + QuicStreamOffset offset, + bool fin, + QuicAckNotifier::DelegateInterface* delegate) { + struct iovec iov = {const_cast<char*>(data.data()), + static_cast<size_t>(data.size())}; + return SendvStreamDataAndNotifyWhenAcked(id, &iov, 1, offset, fin, + delegate); } - QuicConsumedData SendStreamData2() { - return SendStreamData(2u, "food2", 0, !kFin); + QuicConsumedData SendStreamData3() { + return SendStreamData(kStreamId3, "food", 0, !kFin); + } + + QuicConsumedData SendStreamData5() { + return SendStreamData(kStreamId5, "food2", 0, !kFin); + } + + // The crypto stream has special semantics so that it is not blocked by a + // congestion window limitation, and also so that it gets put into a separate + // packet (so that it is easier to reason about a crypto frame not being + // split needlessly across packet boundaries). As a result, we have separate + // tests for some cases for this stream. + QuicConsumedData SendCryptoStreamData() { + this->Flush(); + QuicConsumedData consumed = + SendStreamData(kCryptoStreamId, "chlo", 0, !kFin); + this->Flush(); + return consumed; } bool is_server() { @@ -451,9 +505,23 @@ class QuicConnectionTest : public ::testing::Test { QuicTime::Delta::Zero())); EXPECT_CALL(*receive_algorithm_, RecordIncomingPacket(_, _, _, _)).Times(AnyNumber()); - EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, _)).Times(AnyNumber()); + EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, _, _)).Times(AnyNumber()); EXPECT_CALL(*send_algorithm_, RetransmissionDelay()).WillRepeatedly( Return(QuicTime::Delta::Zero())); + EXPECT_CALL(*send_algorithm_, BandwidthEstimate()).WillRepeatedly(Return( + QuicBandwidth::FromKBitsPerSecond(100))); + EXPECT_CALL(*send_algorithm_, SmoothedRtt()).WillRepeatedly(Return( + QuicTime::Delta::FromMilliseconds(100))); + ON_CALL(*send_algorithm_, SentPacket(_, _, _, _, _)) + .WillByDefault(Return(true)); + // TODO(rch): remove this. + QuicConnection::g_acks_do_not_instigate_acks = true; + EXPECT_CALL(visitor_, HasPendingHandshake()).Times(AnyNumber()); + } + + ~QuicConnectionTest() { + // TODO(rch): remove this. + QuicConnection::g_acks_do_not_instigate_acks = false; } QuicAckFrame* outgoing_ack() { @@ -486,8 +554,7 @@ class QuicConnectionTest : public ::testing::Test { } void ProcessPacket(QuicPacketSequenceNumber number) { - EXPECT_CALL(visitor_, OnPacket(_, _, _, _)) - .WillOnce(Return(accept_packet_)); + EXPECT_CALL(visitor_, OnStreamFrames(_)).WillOnce(Return(accept_packet_)); ProcessDataPacket(number, 0, !kEntropyFlag); } @@ -505,18 +572,6 @@ class QuicConnectionTest : public ::testing::Test { return serialized_packet.entropy_hash; } - size_t ProcessFecProtectedPacket(QuicPacketSequenceNumber number, - bool expect_revival) { - if (expect_revival) { - EXPECT_CALL(visitor_, OnPacket(_, _, _, _)).Times(2).WillRepeatedly( - Return(accept_packet_)); - } else { - EXPECT_CALL(visitor_, OnPacket(_, _, _, _)).WillOnce( - Return(accept_packet_)); - } - return ProcessDataPacket(number, 1, !kEntropyFlag); - } - size_t ProcessDataPacket(QuicPacketSequenceNumber number, QuicFecGroupNumber fec_group, bool entropy_flag) { @@ -547,28 +602,33 @@ class QuicConnectionTest : public ::testing::Test { size_t ProcessFecProtectedPacket(QuicPacketSequenceNumber number, bool expect_revival, bool entropy_flag) { if (expect_revival) { - EXPECT_CALL(visitor_, OnPacket(_, _, _, _)).WillOnce(DoAll( - SaveArg<2>(&revived_header_), Return(accept_packet_))); + EXPECT_CALL(visitor_, OnStreamFrames(_)).WillOnce(Return(accept_packet_)); } - EXPECT_CALL(visitor_, OnPacket(_, _, _, _)).WillOnce(Return(accept_packet_)) + EXPECT_CALL(visitor_, OnStreamFrames(_)).WillOnce(Return(accept_packet_)) .RetiresOnSaturation(); return ProcessDataPacket(number, 1, entropy_flag); } - // Sends an FEC packet that covers the packets that would have been sent. + // Processes an FEC packet that covers the packets that would have been + // received. size_t ProcessFecPacket(QuicPacketSequenceNumber number, QuicPacketSequenceNumber min_protected_packet, bool expect_revival, - bool entropy_flag) { + bool entropy_flag, + QuicPacket* packet) { if (expect_revival) { - EXPECT_CALL(visitor_, OnPacket(_, _, _, _)).WillOnce(DoAll( - SaveArg<2>(&revived_header_), Return(accept_packet_))); + EXPECT_CALL(visitor_, OnStreamFrames(_)).WillOnce(Return(accept_packet_)); } // Construct the decrypted data packet so we can compute the correct - // redundancy. - scoped_ptr<QuicPacket> data_packet(ConstructDataPacket(number, 1, - !kEntropyFlag)); + // redundancy. If |packet| has been provided then use that, otherwise + // construct a default data packet. + scoped_ptr<QuicPacket> data_packet; + if (packet) { + data_packet.reset(packet); + } else { + data_packet.reset(ConstructDataPacket(number, 1, !kEntropyFlag)); + } header_.public_header.guid = guid_; header_.public_header.reset_flag = false; @@ -580,6 +640,7 @@ class QuicConnectionTest : public ::testing::Test { header_.fec_group = min_protected_packet; QuicFecData fec_data; fec_data.fec_group = header_.fec_group; + // Since all data packets in this test have the same payload, the // redundancy is either equal to that payload or the xor of that payload // with itself, depending on the number of packets. @@ -593,6 +654,7 @@ class QuicConnectionTest : public ::testing::Test { } } fec_data.redundancy = data_packet->FecProtectedData(); + scoped_ptr<QuicPacket> fec_packet( framer_.BuildFecPacket(header_, fec_data).packet); scoped_ptr<QuicEncryptedPacket> encrypted( @@ -606,21 +668,21 @@ class QuicConnectionTest : public ::testing::Test { QuicStreamOffset offset, bool fin, QuicPacketSequenceNumber* last_packet) { QuicByteCount packet_size; - EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, _)).WillOnce( - SaveArg<2>(&packet_size)); + EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, _, _)) + .WillOnce(DoAll(SaveArg<2>(&packet_size), Return(true))); connection_.SendStreamData(id, data, offset, fin); if (last_packet != NULL) { *last_packet = QuicConnectionPeer::GetPacketCreator(&connection_)->sequence_number(); } - EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, _)).Times(AnyNumber()); + EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, _, _)).Times(AnyNumber()); return packet_size; } void SendAckPacketToPeer() { - EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, _)).Times(1); + EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, _, _)).Times(1); connection_.SendAck(); - EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, _)).Times(AnyNumber()); + EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, _, _)).Times(AnyNumber()); } QuicPacketEntropyHash ProcessAckPacket(QuicAckFrame* frame, @@ -689,6 +751,10 @@ class QuicConnectionTest : public ::testing::Test { connection_.SetReceiveAlgorithm(receive_algorithm_); } + QuicTime::Delta DefaultRetransmissionTime() { + return QuicTime::Delta::FromMilliseconds(kDefaultRetransmissionTimeMs); + } + QuicGuid guid_; QuicFramer framer_; QuicPacketCreator creator_; @@ -699,10 +765,9 @@ class QuicConnectionTest : public ::testing::Test { MockRandom random_generator_; TestConnectionHelper* helper_; TestConnection connection_; - testing::StrictMock<MockConnectionVisitor> visitor_; + StrictMock<MockConnectionVisitor> visitor_; QuicPacketHeader header_; - QuicPacketHeader revived_header_; QuicStreamFrame frame1_; QuicStreamFrame frame2_; scoped_ptr<QuicAckFrame> outgoing_ack_; @@ -713,6 +778,8 @@ class QuicConnectionTest : public ::testing::Test { }; TEST_F(QuicConnectionTest, PacketsInOrder) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + ProcessPacket(1); EXPECT_EQ(1u, outgoing_ack()->received_info.largest_observed); EXPECT_EQ(0u, outgoing_ack()->received_info.missing_packets.size()); @@ -727,6 +794,8 @@ TEST_F(QuicConnectionTest, PacketsInOrder) { } TEST_F(QuicConnectionTest, PacketsRejected) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + ProcessPacket(1); EXPECT_EQ(1u, outgoing_ack()->received_info.largest_observed); EXPECT_EQ(0u, outgoing_ack()->received_info.missing_packets.size()); @@ -739,6 +808,8 @@ TEST_F(QuicConnectionTest, PacketsRejected) { } TEST_F(QuicConnectionTest, PacketsOutOfOrder) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + ProcessPacket(3); EXPECT_EQ(3u, outgoing_ack()->received_info.largest_observed); EXPECT_TRUE(IsMissing(2)); @@ -756,13 +827,15 @@ TEST_F(QuicConnectionTest, PacketsOutOfOrder) { } TEST_F(QuicConnectionTest, DuplicatePacket) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + ProcessPacket(3); EXPECT_EQ(3u, outgoing_ack()->received_info.largest_observed); EXPECT_TRUE(IsMissing(2)); EXPECT_TRUE(IsMissing(1)); // Send packet 3 again, but do not set the expectation that - // the visitor OnPacket() will be called. + // the visitor OnStreamFrames() will be called. ProcessDataPacket(3, 0, !kEntropyFlag); EXPECT_EQ(3u, outgoing_ack()->received_info.largest_observed); EXPECT_TRUE(IsMissing(2)); @@ -770,6 +843,8 @@ TEST_F(QuicConnectionTest, DuplicatePacket) { } TEST_F(QuicConnectionTest, PacketsOutOfOrderWithAdditionsAndLeastAwaiting) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + ProcessPacket(3); EXPECT_EQ(3u, outgoing_ack()->received_info.largest_observed); EXPECT_TRUE(IsMissing(2)); @@ -805,7 +880,7 @@ TEST_F(QuicConnectionTest, RejectPacketTooFarOut) { } TEST_F(QuicConnectionTest, TruncatedAck) { - EXPECT_CALL(visitor_, OnAck(_)).Times(testing::AnyNumber()); + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); EXPECT_CALL(*send_algorithm_, OnIncomingAck(_, _, _)).Times(2); EXPECT_CALL(*send_algorithm_, OnIncomingLoss(_)).Times(1); for (int i = 0; i < 200; ++i) { @@ -813,26 +888,28 @@ TEST_F(QuicConnectionTest, TruncatedAck) { } QuicAckFrame frame(0, QuicTime::Zero(), 1); - frame.received_info.largest_observed = 192; - InsertMissingPacketsBetween(&frame.received_info, 1, 192); + frame.received_info.largest_observed = 193; + InsertMissingPacketsBetween(&frame.received_info, 1, 193); frame.received_info.entropy_hash = - QuicConnectionPeer::GetSentEntropyHash(&connection_, 192) ^ - QuicConnectionPeer::GetSentEntropyHash(&connection_, 191); + QuicConnectionPeer::GetSentEntropyHash(&connection_, 193) ^ + QuicConnectionPeer::GetSentEntropyHash(&connection_, 192); ProcessAckPacket(&frame, true); EXPECT_TRUE(QuicConnectionPeer::GetReceivedTruncatedAck(&connection_)); - frame.received_info.missing_packets.erase(191); + frame.received_info.missing_packets.erase(192); frame.received_info.entropy_hash = - QuicConnectionPeer::GetSentEntropyHash(&connection_, 192) ^ - QuicConnectionPeer::GetSentEntropyHash(&connection_, 190); + QuicConnectionPeer::GetSentEntropyHash(&connection_, 193) ^ + QuicConnectionPeer::GetSentEntropyHash(&connection_, 191); ProcessAckPacket(&frame, true); EXPECT_FALSE(QuicConnectionPeer::GetReceivedTruncatedAck(&connection_)); } TEST_F(QuicConnectionTest, AckReceiptCausesAckSendBadEntropy) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + ProcessPacket(1); // Delay sending, then queue up an ack. EXPECT_CALL(*send_algorithm_, @@ -852,11 +929,13 @@ TEST_F(QuicConnectionTest, AckReceiptCausesAckSendBadEntropy) { } TEST_F(QuicConnectionTest, AckReceiptCausesAckSend) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); EXPECT_CALL(*send_algorithm_, OnIncomingLoss(_)).Times(1); QuicPacketSequenceNumber largest_observed; QuicByteCount packet_size; - EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, NOT_RETRANSMISSION)) - .WillOnce(DoAll(SaveArg<1>(&largest_observed), SaveArg<2>(&packet_size))); + EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, NOT_RETRANSMISSION, _)) + .WillOnce(DoAll(SaveArg<1>(&largest_observed), SaveArg<2>(&packet_size), + Return(true))); EXPECT_CALL(*send_algorithm_, AbandoningPacket(1, _)).Times(1); connection_.SendStreamData(1, "foo", 0, !kFin); QuicAckFrame frame(1, QuicTime::Zero(), largest_observed); @@ -867,13 +946,13 @@ TEST_F(QuicConnectionTest, AckReceiptCausesAckSend) { ProcessAckPacket(&frame, true); // Third nack should retransmit the largest observed packet. EXPECT_CALL(*send_algorithm_, SentPacket(_, _, packet_size - kQuicVersionSize, - IS_RETRANSMISSION)); + IS_RETRANSMISSION, _)); ProcessAckPacket(&frame, true); // Now if the peer sends an ack which still reports the retransmitted packet // as missing, then that will count as a packet which instigates an ack. ProcessAckPacket(&frame, true); - EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, NOT_RETRANSMISSION)); + EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, NOT_RETRANSMISSION, _)); ProcessAckPacket(&frame, true); // But an ack with no new missing packest will not send an ack. @@ -883,6 +962,8 @@ TEST_F(QuicConnectionTest, AckReceiptCausesAckSend) { } TEST_F(QuicConnectionTest, LeastUnackedLower) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + SendStreamDataToPeer(1, "foo", 0, !kFin, NULL); SendStreamDataToPeer(1, "bar", 3, !kFin, NULL); SendStreamDataToPeer(1, "eep", 6, !kFin, NULL); @@ -902,12 +983,14 @@ TEST_F(QuicConnectionTest, LeastUnackedLower) { // Now claim it's one, but set the ordering so it was sent "after" the first // one. This should cause a connection error. EXPECT_CALL(visitor_, ConnectionClose(QUIC_INVALID_ACK_DATA, false)); - EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, _)); + EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, _, _)); creator_.set_sequence_number(7); ProcessAckPacket(&frame2, false); } TEST_F(QuicConnectionTest, LargestObservedLower) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + SendStreamDataToPeer(1, "foo", 0, !kFin, NULL); SendStreamDataToPeer(1, "bar", 3, !kFin, NULL); SendStreamDataToPeer(1, "eep", 6, !kFin, NULL); @@ -917,7 +1000,6 @@ TEST_F(QuicConnectionTest, LargestObservedLower) { QuicAckFrame frame(2, QuicTime::Zero(), 0); frame.received_info.entropy_hash = QuicConnectionPeer::GetSentEntropyHash( &connection_, 2); - EXPECT_CALL(visitor_, OnAck(_)); ProcessAckPacket(&frame, true); // Now change it to 1, and it should cause a connection error. @@ -927,8 +1009,9 @@ TEST_F(QuicConnectionTest, LargestObservedLower) { } TEST_F(QuicConnectionTest, LeastUnackedGreaterThanPacketSequenceNumber) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); EXPECT_CALL(visitor_, ConnectionClose(QUIC_INVALID_ACK_DATA, false)); - EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, _)); + EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, _, _)); // Create an ack with least_unacked is 2 in packet number 1. creator_.set_sequence_number(0); QuicAckFrame frame(0, QuicTime::Zero(), 2); @@ -937,12 +1020,14 @@ TEST_F(QuicConnectionTest, LeastUnackedGreaterThanPacketSequenceNumber) { TEST_F(QuicConnectionTest, NackSequenceNumberGreaterThanLargestReceived) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + SendStreamDataToPeer(1, "foo", 0, !kFin, NULL); SendStreamDataToPeer(1, "bar", 3, !kFin, NULL); SendStreamDataToPeer(1, "eep", 6, !kFin, NULL); EXPECT_CALL(visitor_, ConnectionClose(QUIC_INVALID_ACK_DATA, false)); - EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, _)); + EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, _, _)); QuicAckFrame frame(0, QuicTime::Zero(), 1); frame.received_info.missing_packets.insert(3); ProcessAckPacket(&frame, false); @@ -951,12 +1036,14 @@ TEST_F(QuicConnectionTest, TEST_F(QuicConnectionTest, AckUnsentData) { // Ack a packet which has not been sent. EXPECT_CALL(visitor_, ConnectionClose(QUIC_INVALID_ACK_DATA, false)); - EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, _)); + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, _, _)); QuicAckFrame frame(1, QuicTime::Zero(), 0); ProcessAckPacket(&frame, false); } TEST_F(QuicConnectionTest, AckAll) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); ProcessPacket(1); creator_.set_sequence_number(1); @@ -964,7 +1051,108 @@ TEST_F(QuicConnectionTest, AckAll) { ProcessAckPacket(&frame1, true); } +TEST_F(QuicConnectionTest, SendingDifferentSequenceNumberLengthsBandwidth) { + EXPECT_CALL(*send_algorithm_, BandwidthEstimate()).WillOnce(Return( + QuicBandwidth::FromKBitsPerSecond(1000))); + + QuicPacketSequenceNumber last_packet; + SendStreamDataToPeer(1, "foo", 0, !kFin, &last_packet); + EXPECT_EQ(1u, last_packet); + EXPECT_EQ(PACKET_1BYTE_SEQUENCE_NUMBER, + connection_.options()->send_sequence_number_length); + EXPECT_EQ(PACKET_1BYTE_SEQUENCE_NUMBER, + last_header()->public_header.sequence_number_length); + + EXPECT_CALL(*send_algorithm_, BandwidthEstimate()).WillOnce(Return( + QuicBandwidth::FromKBitsPerSecond(1000 * 256))); + + SendStreamDataToPeer(1u, "bar", 3, !kFin, &last_packet); + EXPECT_EQ(2u, last_packet); + EXPECT_EQ(PACKET_2BYTE_SEQUENCE_NUMBER, + connection_.options()->send_sequence_number_length); + // The 1 packet lag is due to the sequence number length being recalculated in + // QuicConnection after a packet is sent. + EXPECT_EQ(PACKET_1BYTE_SEQUENCE_NUMBER, + last_header()->public_header.sequence_number_length); + + EXPECT_CALL(*send_algorithm_, BandwidthEstimate()).WillOnce(Return( + QuicBandwidth::FromKBitsPerSecond(1000 * 256 * 256))); + + SendStreamDataToPeer(1, "foo", 6, !kFin, &last_packet); + EXPECT_EQ(3u, last_packet); + EXPECT_EQ(PACKET_4BYTE_SEQUENCE_NUMBER, + connection_.options()->send_sequence_number_length); + EXPECT_EQ(PACKET_2BYTE_SEQUENCE_NUMBER, + last_header()->public_header.sequence_number_length); + + EXPECT_CALL(*send_algorithm_, BandwidthEstimate()).WillOnce(Return( + QuicBandwidth::FromKBitsPerSecond(1000ll * 256 * 256 * 256))); + + SendStreamDataToPeer(1u, "bar", 9, !kFin, &last_packet); + EXPECT_EQ(4u, last_packet); + EXPECT_EQ(PACKET_4BYTE_SEQUENCE_NUMBER, + connection_.options()->send_sequence_number_length); + EXPECT_EQ(PACKET_4BYTE_SEQUENCE_NUMBER, + last_header()->public_header.sequence_number_length); + + EXPECT_CALL(*send_algorithm_, BandwidthEstimate()).WillOnce(Return( + QuicBandwidth::FromKBitsPerSecond(1000ll * 256 * 256 * 256 * 256))); + + SendStreamDataToPeer(1u, "foo", 12, !kFin, &last_packet); + EXPECT_EQ(5u, last_packet); + EXPECT_EQ(PACKET_6BYTE_SEQUENCE_NUMBER, + connection_.options()->send_sequence_number_length); + EXPECT_EQ(PACKET_4BYTE_SEQUENCE_NUMBER, + last_header()->public_header.sequence_number_length); +} + +TEST_F(QuicConnectionTest, SendingDifferentSequenceNumberLengthsUnackedDelta) { + QuicPacketSequenceNumber last_packet; + SendStreamDataToPeer(1, "foo", 0, !kFin, &last_packet); + EXPECT_EQ(1u, last_packet); + EXPECT_EQ(PACKET_1BYTE_SEQUENCE_NUMBER, + connection_.options()->send_sequence_number_length); + EXPECT_EQ(PACKET_1BYTE_SEQUENCE_NUMBER, + last_header()->public_header.sequence_number_length); + + QuicConnectionPeer::GetPacketCreator(&connection_)->set_sequence_number(100); + + SendStreamDataToPeer(1u, "bar", 3, !kFin, &last_packet); + EXPECT_EQ(PACKET_2BYTE_SEQUENCE_NUMBER, + connection_.options()->send_sequence_number_length); + EXPECT_EQ(PACKET_1BYTE_SEQUENCE_NUMBER, + last_header()->public_header.sequence_number_length); + + QuicConnectionPeer::GetPacketCreator(&connection_)->set_sequence_number( + 100 * 256); + + SendStreamDataToPeer(1, "foo", 6, !kFin, &last_packet); + EXPECT_EQ(PACKET_4BYTE_SEQUENCE_NUMBER, + connection_.options()->send_sequence_number_length); + EXPECT_EQ(PACKET_2BYTE_SEQUENCE_NUMBER, + last_header()->public_header.sequence_number_length); + + QuicConnectionPeer::GetPacketCreator(&connection_)->set_sequence_number( + 100 * 256 * 256); + + SendStreamDataToPeer(1u, "bar", 9, !kFin, &last_packet); + EXPECT_EQ(PACKET_4BYTE_SEQUENCE_NUMBER, + connection_.options()->send_sequence_number_length); + EXPECT_EQ(PACKET_4BYTE_SEQUENCE_NUMBER, + last_header()->public_header.sequence_number_length); + + QuicConnectionPeer::GetPacketCreator(&connection_)->set_sequence_number( + 100 * 256 * 256 * 256); + + SendStreamDataToPeer(1u, "foo", 12, !kFin, &last_packet); + EXPECT_EQ(PACKET_6BYTE_SEQUENCE_NUMBER, + connection_.options()->send_sequence_number_length); + EXPECT_EQ(PACKET_4BYTE_SEQUENCE_NUMBER, + last_header()->public_header.sequence_number_length); +} + TEST_F(QuicConnectionTest, BasicSending) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); EXPECT_CALL(*send_algorithm_, OnIncomingAck(_, _, _)).Times(6); QuicPacketSequenceNumber last_packet; SendStreamDataToPeer(1, "foo", 0, !kFin, &last_packet); // Packet 1 @@ -981,11 +1169,7 @@ TEST_F(QuicConnectionTest, BasicSending) { SendAckPacketToPeer(); // Packet 5 EXPECT_EQ(1u, last_ack()->sent_info.least_unacked); - SequenceNumberSet expected_acks; - expected_acks.insert(1); - // Peer acks up to packet 3. - EXPECT_CALL(visitor_, OnAck(ContainerEq(expected_acks))); QuicAckFrame frame(3, QuicTime::Zero(), 0); frame.received_info.entropy_hash = QuicConnectionPeer::GetSentEntropyHash(&connection_, 3); @@ -996,11 +1180,7 @@ TEST_F(QuicConnectionTest, BasicSending) { // ack for 4. EXPECT_EQ(4u, last_ack()->sent_info.least_unacked); - expected_acks.clear(); - expected_acks.insert(4); - // Peer acks up to packet 4, the last packet. - EXPECT_CALL(visitor_, OnAck(ContainerEq(expected_acks))); QuicAckFrame frame2(6, QuicTime::Zero(), 0); frame2.received_info.entropy_hash = QuicConnectionPeer::GetSentEntropyHash(&connection_, 6); @@ -1027,13 +1207,14 @@ TEST_F(QuicConnectionTest, FECSending) { // All packets carry version info till version is negotiated. size_t payload_length; connection_.options()->max_packet_length = - GetPacketLengthForOneStream(connection_.version(), kIncludeVersion, - IN_FEC_GROUP, &payload_length); + GetPacketLengthForOneStream( + connection_.version(), kIncludeVersion, PACKET_1BYTE_SEQUENCE_NUMBER, + IN_FEC_GROUP, &payload_length); // And send FEC every two packets. connection_.options()->max_packets_per_fec_group = 2; // Send 4 data packets and 2 FEC packets. - EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, _)).Times(6); + EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, _, _)).Times(6); // The first stream frame will consume 2 fewer bytes than the other three. const string payload(payload_length * 4 - 6, 'a'); connection_.SendStreamData(1, payload, 0, !kFin); @@ -1045,8 +1226,9 @@ TEST_F(QuicConnectionTest, FECQueueing) { // All packets carry version info till version is negotiated. size_t payload_length; connection_.options()->max_packet_length = - GetPacketLengthForOneStream(connection_.version(), kIncludeVersion, - IN_FEC_GROUP, &payload_length); + GetPacketLengthForOneStream( + connection_.version(), kIncludeVersion, PACKET_1BYTE_SEQUENCE_NUMBER, + IN_FEC_GROUP, &payload_length); // And send FEC every two packets. connection_.options()->max_packets_per_fec_group = 2; @@ -1062,7 +1244,7 @@ TEST_F(QuicConnectionTest, FECQueueing) { TEST_F(QuicConnectionTest, AbandonFECFromCongestionWindow) { connection_.options()->max_packets_per_fec_group = 1; // 1 Data and 1 FEC packet. - EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, _)).Times(2); + EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, _, _)).Times(2); connection_.SendStreamData(1, "foo", 0, !kFin); // Larger timeout for FEC bytes to expire. @@ -1071,7 +1253,7 @@ TEST_F(QuicConnectionTest, AbandonFECFromCongestionWindow) { clock_.AdvanceTime(retransmission_time); // Send only data packet. - EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, _)).Times(1); + EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, _, _)).Times(1); // Abandon both FEC and data packet. EXPECT_CALL(*send_algorithm_, AbandoningPacket(_, _)).Times(2); @@ -1079,12 +1261,13 @@ TEST_F(QuicConnectionTest, AbandonFECFromCongestionWindow) { } TEST_F(QuicConnectionTest, DontAbandonAckedFEC) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); connection_.options()->max_packets_per_fec_group = 1; const QuicPacketSequenceNumber sequence_number = QuicConnectionPeer::GetPacketCreator(&connection_)->sequence_number() + 1; // 1 Data and 1 FEC packet. - EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, _)).Times(2); + EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, _, _)).Times(2); connection_.SendStreamData(1, "foo", 0, !kFin); QuicAckFrame ack_fec(2, QuicTime::Zero(), 1); @@ -1094,20 +1277,17 @@ TEST_F(QuicConnectionTest, DontAbandonAckedFEC) { QuicConnectionPeer::GetSentEntropyHash(&connection_, 2) ^ QuicConnectionPeer::GetSentEntropyHash(&connection_, 1); - EXPECT_CALL(visitor_, OnAck(_)).Times(1); EXPECT_CALL(*send_algorithm_, OnIncomingAck(_, _, _)).Times(1); EXPECT_CALL(*send_algorithm_, OnIncomingLoss(_)).Times(1); ProcessAckPacket(&ack_fec, true); - const QuicTime::Delta kDefaultRetransmissionTime = - QuicTime::Delta::FromMilliseconds(5000); - clock_.AdvanceTime(kDefaultRetransmissionTime); + clock_.AdvanceTime(DefaultRetransmissionTime()); // Abandon only data packet, FEC has been acked. EXPECT_CALL(*send_algorithm_, AbandoningPacket(sequence_number, _)).Times(1); // Send only data packet. - EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, _)).Times(1); + EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, _, _)).Times(1); connection_.OnRetransmissionTimeout(); } @@ -1120,15 +1300,15 @@ TEST_F(QuicConnectionTest, FramePacking) { connection_.SendAck(); EXPECT_CALL(visitor_, OnCanWrite()).WillOnce(DoAll( IgnoreResult(InvokeWithoutArgs(&connection_, - &TestConnection::SendStreamData1)), + &TestConnection::SendStreamData3)), IgnoreResult(InvokeWithoutArgs(&connection_, - &TestConnection::SendStreamData2)), + &TestConnection::SendStreamData5)), Return(true))); // Unblock the connection. connection_.GetSendAlarm()->Cancel(); EXPECT_CALL(*send_algorithm_, - SentPacket(_, _, _, NOT_RETRANSMISSION)) + SentPacket(_, _, _, NOT_RETRANSMISSION, _)) .Times(1); connection_.OnCanWrite(); EXPECT_EQ(0u, connection_.NumQueuedPackets()); @@ -1139,8 +1319,70 @@ TEST_F(QuicConnectionTest, FramePacking) { EXPECT_EQ(3u, helper_->frame_count()); EXPECT_TRUE(helper_->ack()); EXPECT_EQ(2u, helper_->stream_frames()->size()); - EXPECT_EQ(1u, (*helper_->stream_frames())[0].stream_id); - EXPECT_EQ(2u, (*helper_->stream_frames())[1].stream_id); + EXPECT_EQ(kStreamId3, (*helper_->stream_frames())[0].stream_id); + EXPECT_EQ(kStreamId5, (*helper_->stream_frames())[1].stream_id); +} + +TEST_F(QuicConnectionTest, FramePackingNonCryptoThenCrypto) { + // Block the connection. + connection_.GetSendAlarm()->Set( + clock_.ApproximateNow().Add(QuicTime::Delta::FromSeconds(1))); + + // Send an ack and two stream frames (one non-crypto, then one crypto) in 2 + // packets by queueing them. + connection_.SendAck(); + EXPECT_CALL(visitor_, OnCanWrite()).WillOnce(DoAll( + IgnoreResult(InvokeWithoutArgs(&connection_, + &TestConnection::SendStreamData3)), + IgnoreResult(InvokeWithoutArgs(&connection_, + &TestConnection::SendCryptoStreamData)), + Return(true))); + + // Unblock the connection. + connection_.GetSendAlarm()->Cancel(); + EXPECT_CALL(*send_algorithm_, + SentPacket(_, _, _, NOT_RETRANSMISSION, _)) + .Times(2); + connection_.OnCanWrite(); + EXPECT_EQ(0u, connection_.NumQueuedPackets()); + EXPECT_FALSE(connection_.HasQueuedData()); + + // Parse the last packet and ensure it's the crypto stream frame. + EXPECT_EQ(1u, helper_->frame_count()); + EXPECT_TRUE(helper_->ack()); + EXPECT_EQ(1u, helper_->stream_frames()->size()); + EXPECT_EQ(kCryptoStreamId, (*helper_->stream_frames())[0].stream_id); +} + +TEST_F(QuicConnectionTest, FramePackingCryptoThenNonCrypto) { + // Block the connection. + connection_.GetSendAlarm()->Set( + clock_.ApproximateNow().Add(QuicTime::Delta::FromSeconds(1))); + + // Send an ack and two stream frames (one crypto, then one non-crypto) in 3 + // packets by queueing them. + connection_.SendAck(); + EXPECT_CALL(visitor_, OnCanWrite()).WillOnce(DoAll( + IgnoreResult(InvokeWithoutArgs(&connection_, + &TestConnection::SendCryptoStreamData)), + IgnoreResult(InvokeWithoutArgs(&connection_, + &TestConnection::SendStreamData3)), + Return(true))); + + // Unblock the connection. + connection_.GetSendAlarm()->Cancel(); + EXPECT_CALL(*send_algorithm_, + SentPacket(_, _, _, NOT_RETRANSMISSION, _)) + .Times(3); + connection_.OnCanWrite(); + EXPECT_EQ(0u, connection_.NumQueuedPackets()); + EXPECT_FALSE(connection_.HasQueuedData()); + + // Parse the last packet and ensure it's the stream frame from stream 3. + EXPECT_EQ(1u, helper_->frame_count()); + EXPECT_TRUE(helper_->ack()); + EXPECT_EQ(1u, helper_->stream_frames()->size()); + EXPECT_EQ(kStreamId3, (*helper_->stream_frames())[0].stream_id); } TEST_F(QuicConnectionTest, FramePackingFEC) { @@ -1154,15 +1396,15 @@ TEST_F(QuicConnectionTest, FramePackingFEC) { connection_.SendAck(); EXPECT_CALL(visitor_, OnCanWrite()).WillOnce(DoAll( IgnoreResult(InvokeWithoutArgs(&connection_, - &TestConnection::SendStreamData1)), + &TestConnection::SendStreamData3)), IgnoreResult(InvokeWithoutArgs(&connection_, - &TestConnection::SendStreamData2)), + &TestConnection::SendStreamData5)), Return(true))); // Unblock the connection. connection_.GetSendAlarm()->Cancel(); EXPECT_CALL(*send_algorithm_, - SentPacket(_, _, _, NOT_RETRANSMISSION)).Times(2); + SentPacket(_, _, _, NOT_RETRANSMISSION, _)).Times(2); connection_.OnCanWrite(); EXPECT_EQ(0u, connection_.NumQueuedPackets()); EXPECT_FALSE(connection_.HasQueuedData()); @@ -1172,13 +1414,83 @@ TEST_F(QuicConnectionTest, FramePackingFEC) { EXPECT_EQ(0u, helper_->frame_count()); } +TEST_F(QuicConnectionTest, FramePackingSendv) { + // Send two stream frames in 1 packet by using writev. + EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, NOT_RETRANSMISSION, _)); + + char data[] = "ABCD"; + iovec iov[2] = { {static_cast<void*>(data), 2}, + {static_cast<void*>(data + 2), 2} }; + connection_.SendvStreamData(1, iov, 2, 0, !kFin); + + EXPECT_EQ(0u, connection_.NumQueuedPackets()); + EXPECT_FALSE(connection_.HasQueuedData()); + + // Parse the last packet and ensure it's two stream frames from one stream. + // TODO(ianswett): Ideally this would arrive in one frame in the future. + EXPECT_EQ(2u, helper_->frame_count()); + EXPECT_EQ(2u, helper_->stream_frames()->size()); + EXPECT_EQ(1u, (*helper_->stream_frames())[0].stream_id); + EXPECT_EQ(1u, (*helper_->stream_frames())[1].stream_id); +} + +TEST_F(QuicConnectionTest, FramePackingSendvQueued) { + // Try to send two stream frames in 1 packet by using writev. + EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, NOT_RETRANSMISSION, _)); + + helper_->set_blocked(true); + char data[] = "ABCD"; + iovec iov[2] = { {static_cast<void*>(data), 2}, + {static_cast<void*>(data + 2), 2} }; + connection_.SendvStreamData(1, iov, 2, 0, !kFin); + + EXPECT_EQ(1u, connection_.NumQueuedPackets()); + EXPECT_TRUE(connection_.HasQueuedData()); + + // Attempt to send all packets, but since we're actually still + // blocked, they should all remain queued. + EXPECT_FALSE(connection_.OnCanWrite()); + EXPECT_EQ(1u, connection_.NumQueuedPackets()); + + // Unblock the writes and actually send. + helper_->set_blocked(false); + EXPECT_CALL(visitor_, OnCanWrite()); + EXPECT_TRUE(connection_.OnCanWrite()); + EXPECT_EQ(0u, connection_.NumQueuedPackets()); + + // Parse the last packet and ensure it's two stream frames from one stream. + // TODO(ianswett): Ideally this would arrive in one frame in the future. + EXPECT_EQ(2u, helper_->frame_count()); + EXPECT_EQ(2u, helper_->stream_frames()->size()); + EXPECT_EQ(1u, (*helper_->stream_frames())[0].stream_id); + EXPECT_EQ(1u, (*helper_->stream_frames())[1].stream_id); +} + +TEST_F(QuicConnectionTest, SendingZeroBytes) { + // Send a zero byte write with a fin using writev. + EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, NOT_RETRANSMISSION, _)); + + iovec iov[1]; + connection_.SendvStreamData(1, iov, 0, 0, kFin); + + EXPECT_EQ(0u, connection_.NumQueuedPackets()); + EXPECT_FALSE(connection_.HasQueuedData()); + + // Parse the last packet and ensure it's two stream frames from one stream. + // TODO(ianswett): Ideally this would arrive in one frame in the future. + EXPECT_EQ(1u, helper_->frame_count()); + EXPECT_EQ(1u, helper_->stream_frames()->size()); + EXPECT_EQ(1u, (*helper_->stream_frames())[0].stream_id); + EXPECT_TRUE((*helper_->stream_frames())[0].fin); +} + TEST_F(QuicConnectionTest, OnCanWrite) { // Visitor's OnCanWill send data, but will return false. EXPECT_CALL(visitor_, OnCanWrite()).WillOnce(DoAll( IgnoreResult(InvokeWithoutArgs(&connection_, - &TestConnection::SendStreamData1)), + &TestConnection::SendStreamData3)), IgnoreResult(InvokeWithoutArgs(&connection_, - &TestConnection::SendStreamData2)), + &TestConnection::SendStreamData5)), Return(false))); EXPECT_CALL(*send_algorithm_, @@ -1191,8 +1503,8 @@ TEST_F(QuicConnectionTest, OnCanWrite) { // two different streams. EXPECT_EQ(2u, helper_->frame_count()); EXPECT_EQ(2u, helper_->stream_frames()->size()); - EXPECT_EQ(1u, (*helper_->stream_frames())[0].stream_id); - EXPECT_EQ(2u, (*helper_->stream_frames())[1].stream_id); + EXPECT_EQ(kStreamId3, (*helper_->stream_frames())[0].stream_id); + EXPECT_EQ(kStreamId5, (*helper_->stream_frames())[1].stream_id); } TEST_F(QuicConnectionTest, RetransmitOnNack) { @@ -1206,9 +1518,7 @@ TEST_F(QuicConnectionTest, RetransmitOnNack) { SendStreamDataToPeer(1, "foos", 3, !kFin, &last_packet); // Packet 2 SendStreamDataToPeer(1, "fooos", 7, !kFin, &last_packet); // Packet 3 - SequenceNumberSet expected_acks; - expected_acks.insert(1); - EXPECT_CALL(visitor_, OnAck(ContainerEq(expected_acks))); + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); // Peer acks one but not two or three. Right now we only retransmit on // explicit nack, so it should not trigger a retransimission. @@ -1219,10 +1529,6 @@ TEST_F(QuicConnectionTest, RetransmitOnNack) { ProcessAckPacket(&ack_one, true); ProcessAckPacket(&ack_one, true); - expected_acks.clear(); - expected_acks.insert(3); - EXPECT_CALL(visitor_, OnAck(ContainerEq(expected_acks))); - // Peer acks up to 3 with two explicitly missing. Two nacks should cause no // change. QuicAckFrame nack_two(3, QuicTime::Zero(), 0); @@ -1237,16 +1543,18 @@ TEST_F(QuicConnectionTest, RetransmitOnNack) { // The third nack should trigger a retransimission. EXPECT_CALL(*send_algorithm_, SentPacket(_, _, second_packet_size - kQuicVersionSize, - IS_RETRANSMISSION)).Times(1); + IS_RETRANSMISSION, _)).Times(1); ProcessAckPacket(&nack_two, true); } TEST_F(QuicConnectionTest, RetransmitNackedLargestObserved) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); EXPECT_CALL(*send_algorithm_, OnIncomingLoss(_)).Times(1); QuicPacketSequenceNumber largest_observed; QuicByteCount packet_size; - EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, NOT_RETRANSMISSION)) - .WillOnce(DoAll(SaveArg<1>(&largest_observed), SaveArg<2>(&packet_size))); + EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, NOT_RETRANSMISSION, _)) + .WillOnce(DoAll(SaveArg<1>(&largest_observed), SaveArg<2>(&packet_size), + Return(true))); EXPECT_CALL(*send_algorithm_, AbandoningPacket(1, _)).Times(1); connection_.SendStreamData(1, "foo", 0, !kFin); QuicAckFrame frame(1, QuicTime::Zero(), largest_observed); @@ -1257,50 +1565,48 @@ TEST_F(QuicConnectionTest, RetransmitNackedLargestObserved) { ProcessAckPacket(&frame, true); // Third nack should retransmit the largest observed packet. EXPECT_CALL(*send_algorithm_, SentPacket(_, _, packet_size - kQuicVersionSize, - IS_RETRANSMISSION)); + IS_RETRANSMISSION, _)); ProcessAckPacket(&frame, true); } TEST_F(QuicConnectionTest, RetransmitNackedPacketsOnTruncatedAck) { for (int i = 0; i < 200; ++i) { - EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, _)).Times(1); + EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, _, _)).Times(1); connection_.SendStreamData(1, "foo", i * 3, !kFin); } // Make a truncated ack frame. QuicAckFrame frame(0, QuicTime::Zero(), 1); - frame.received_info.largest_observed = 192; - InsertMissingPacketsBetween(&frame.received_info, 1, 192); + frame.received_info.largest_observed = 193; + InsertMissingPacketsBetween(&frame.received_info, 1, 193); frame.received_info.entropy_hash = - QuicConnectionPeer::GetSentEntropyHash(&connection_, 192) ^ - QuicConnectionPeer::GetSentEntropyHash(&connection_, 191); - + QuicConnectionPeer::GetSentEntropyHash(&connection_, 193) ^ + QuicConnectionPeer::GetSentEntropyHash(&connection_, 192); EXPECT_CALL(*send_algorithm_, OnIncomingAck(_, _, _)).Times(1); EXPECT_CALL(*send_algorithm_, OnIncomingLoss(_)).Times(1); - EXPECT_CALL(visitor_, OnAck(_)).Times(1); + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); ProcessAckPacket(&frame, true); EXPECT_TRUE(QuicConnectionPeer::GetReceivedTruncatedAck(&connection_)); QuicConnectionPeer::SetMaxPacketsPerRetransmissionAlarm(&connection_, 200); - const QuicTime::Delta kDefaultRetransmissionTime = - QuicTime::Delta::FromMilliseconds(500); - clock_.AdvanceTime(kDefaultRetransmissionTime); + clock_.AdvanceTime(DefaultRetransmissionTime()); // Only packets that are less than largest observed should be retransmitted. - EXPECT_CALL(*send_algorithm_, AbandoningPacket(_, _)).Times(191); - EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, _)).Times(191); + EXPECT_CALL(*send_algorithm_, AbandoningPacket(_, _)).Times(192); + EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, _, _)).Times(192); connection_.OnRetransmissionTimeout(); clock_.AdvanceTime(QuicTime::Delta::FromMicroseconds( - 2 * kDefaultRetransmissionTime.ToMicroseconds())); + 2 * DefaultRetransmissionTime().ToMicroseconds())); // Retransmit already retransmitted packets event though the sequence number // greater than the largest observed. - EXPECT_CALL(*send_algorithm_, AbandoningPacket(_, _)).Times(191); - EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, _)).Times(191); + EXPECT_CALL(*send_algorithm_, AbandoningPacket(_, _)).Times(192); + EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, _, _)).Times(192); connection_.OnRetransmissionTimeout(); } TEST_F(QuicConnectionTest, LimitPacketsPerNack) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); EXPECT_CALL(*send_algorithm_, OnIncomingAck(12, _, _)).Times(1); EXPECT_CALL(*send_algorithm_, OnIncomingLoss(_)).Times(1); EXPECT_CALL(*send_algorithm_, AbandoningPacket(_, _)).Times(11); @@ -1320,19 +1626,16 @@ TEST_F(QuicConnectionTest, LimitPacketsPerNack) { nack.received_info.entropy_hash = QuicConnectionPeer::GetSentEntropyHash(&connection_, 12) ^ QuicConnectionPeer::GetSentEntropyHash(&connection_, 11); - SequenceNumberSet expected_acks; - expected_acks.insert(12); - EXPECT_CALL(visitor_, OnAck(ContainerEq(expected_acks))); // Nack three times. ProcessAckPacket(&nack, true); ProcessAckPacket(&nack, true); // The third call should trigger retransmitting 10 packets. - EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, _)).Times(10); + EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, _, _)).Times(10); ProcessAckPacket(&nack, true); // The fourth call should trigger retransmitting the 11th packet. - EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, _)).Times(1); + EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, _, _)).Times(1); ProcessAckPacket(&nack, true); } @@ -1361,26 +1664,15 @@ TEST_F(QuicConnectionTest, MultipleAcks) { QuicConnectionPeer::GetSentEntropyHash(&connection_, 2) ^ QuicConnectionPeer::GetSentEntropyHash(&connection_, 1); - // The connection should pass up acks for 1, 4, 5. 2 is not acked, and 3 was - // an ackframe so should not be passed up. - SequenceNumberSet expected_acks; - expected_acks.insert(1); - expected_acks.insert(4); - expected_acks.insert(5); + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); - EXPECT_CALL(visitor_, OnAck(ContainerEq(expected_acks))); ProcessAckPacket(&frame1, true); // Now the client implicitly acks 2, and explicitly acks 6 QuicAckFrame frame2(6, QuicTime::Zero(), 0); frame2.received_info.entropy_hash = QuicConnectionPeer::GetSentEntropyHash(&connection_, 6); - expected_acks.clear(); - // Both acks should be passed up. - expected_acks.insert(2); - expected_acks.insert(6); - EXPECT_CALL(visitor_, OnAck(ContainerEq(expected_acks))); ProcessAckPacket(&frame2, true); } @@ -1389,12 +1681,7 @@ TEST_F(QuicConnectionTest, DontLatchUnackedPacket) { SendStreamDataToPeer(1, "foo", 0, !kFin, NULL); // Packet 1; SendAckPacketToPeer(); // Packet 2 - // This sets least unacked to 3 (unsent packet), since we don't need - // an ack for Packet 2 (ack packet). - SequenceNumberSet expected_acks; - expected_acks.insert(1); - // Peer acks packet 1. - EXPECT_CALL(visitor_, OnAck(ContainerEq(expected_acks))); + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); QuicAckFrame frame(1, QuicTime::Zero(), 0); frame.received_info.entropy_hash = QuicConnectionPeer::GetSentEntropyHash( &connection_, 1); @@ -1419,50 +1706,62 @@ TEST_F(QuicConnectionTest, DontLatchUnackedPacket) { } TEST_F(QuicConnectionTest, ReviveMissingPacketAfterFecPacket) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + // Don't send missing packet 1. - ProcessFecPacket(2, 1, true, !kEntropyFlag); - EXPECT_FALSE(revived_header_.entropy_flag); + ProcessFecPacket(2, 1, true, !kEntropyFlag, NULL); + // Entropy flag should be false, so entropy should be 0. + EXPECT_EQ(0u, QuicConnectionPeer::ReceivedEntropyHash(&connection_, 2)); } TEST_F(QuicConnectionTest, ReviveMissingPacketAfterDataPacketThenFecPacket) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + ProcessFecProtectedPacket(1, false, kEntropyFlag); // Don't send missing packet 2. - ProcessFecPacket(3, 1, true, !kEntropyFlag); - EXPECT_TRUE(revived_header_.entropy_flag); + ProcessFecPacket(3, 1, true, !kEntropyFlag, NULL); + // Entropy flag should be true, so entropy should not be 0. + EXPECT_NE(0u, QuicConnectionPeer::ReceivedEntropyHash(&connection_, 2)); } TEST_F(QuicConnectionTest, ReviveMissingPacketAfterDataPacketsThenFecPacket) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + ProcessFecProtectedPacket(1, false, !kEntropyFlag); // Don't send missing packet 2. ProcessFecProtectedPacket(3, false, !kEntropyFlag); - ProcessFecPacket(4, 1, true, kEntropyFlag); - EXPECT_TRUE(revived_header_.entropy_flag); + ProcessFecPacket(4, 1, true, kEntropyFlag, NULL); + // Entropy flag should be true, so entropy should not be 0. + EXPECT_NE(0u, QuicConnectionPeer::ReceivedEntropyHash(&connection_, 2)); } TEST_F(QuicConnectionTest, ReviveMissingPacketAfterDataPacket) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + // Don't send missing packet 1. - ProcessFecPacket(3, 1, false, !kEntropyFlag); + ProcessFecPacket(3, 1, false, !kEntropyFlag, NULL); // out of order ProcessFecProtectedPacket(2, true, !kEntropyFlag); - EXPECT_FALSE(revived_header_.entropy_flag); + // Entropy flag should be false, so entropy should be 0. + EXPECT_EQ(0u, QuicConnectionPeer::ReceivedEntropyHash(&connection_, 2)); } TEST_F(QuicConnectionTest, ReviveMissingPacketAfterDataPackets) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + ProcessFecProtectedPacket(1, false, !kEntropyFlag); // Don't send missing packet 2. - ProcessFecPacket(6, 1, false, kEntropyFlag); + ProcessFecPacket(6, 1, false, kEntropyFlag, NULL); ProcessFecProtectedPacket(3, false, kEntropyFlag); ProcessFecProtectedPacket(4, false, kEntropyFlag); ProcessFecProtectedPacket(5, true, !kEntropyFlag); - EXPECT_TRUE(revived_header_.entropy_flag); + // Entropy flag should be true, so entropy should be 0. + EXPECT_NE(0u, QuicConnectionPeer::ReceivedEntropyHash(&connection_, 2)); } TEST_F(QuicConnectionTest, TestRetransmit) { - const QuicTime::Delta kDefaultRetransmissionTime = - QuicTime::Delta::FromMilliseconds(500); - QuicTime default_retransmission_time = clock_.ApproximateNow().Add( - kDefaultRetransmissionTime); + DefaultRetransmissionTime()); SendStreamDataToPeer(1, "foo", 0, !kFin, NULL); EXPECT_EQ(1u, outgoing_ack()->sent_info.least_unacked); @@ -1470,8 +1769,8 @@ TEST_F(QuicConnectionTest, TestRetransmit) { EXPECT_EQ(default_retransmission_time, connection_.GetRetransmissionAlarm()->deadline()); // Simulate the retransimission alarm firing - clock_.AdvanceTime(kDefaultRetransmissionTime); - EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, _)); + clock_.AdvanceTime(DefaultRetransmissionTime()); + EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, _, _)); EXPECT_CALL(*send_algorithm_, AbandoningPacket(1, _)).Times(1); connection_.RetransmitPacket(1); EXPECT_EQ(2u, last_header()->packet_sequence_number); @@ -1479,11 +1778,8 @@ TEST_F(QuicConnectionTest, TestRetransmit) { } TEST_F(QuicConnectionTest, RetransmitWithSameEncryptionLevel) { - const QuicTime::Delta kDefaultRetransmissionTime = - QuicTime::Delta::FromMilliseconds(500); - QuicTime default_retransmission_time = clock_.ApproximateNow().Add( - kDefaultRetransmissionTime); + DefaultRetransmissionTime()); use_tagging_decrypter(); // A TaggingEncrypter puts kTagSize copies of the given byte (0x01 here) at @@ -1500,15 +1796,15 @@ TEST_F(QuicConnectionTest, RetransmitWithSameEncryptionLevel) { EXPECT_EQ(default_retransmission_time, connection_.GetRetransmissionAlarm()->deadline()); // Simulate the retransimission alarm firing - clock_.AdvanceTime(kDefaultRetransmissionTime); + clock_.AdvanceTime(DefaultRetransmissionTime()); EXPECT_CALL(*send_algorithm_, AbandoningPacket(_, _)).Times(2); - EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, _)); + EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, _, _)); connection_.RetransmitPacket(1); // Packet should have been sent with ENCRYPTION_NONE. EXPECT_EQ(0x01010101u, final_bytes_of_last_packet()); - EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, _)); + EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, _, _)); connection_.RetransmitPacket(2); // Packet should have been sent with ENCRYPTION_INITIAL. EXPECT_EQ(0x02020202u, final_bytes_of_last_packet()); @@ -1525,18 +1821,16 @@ TEST_F(QuicConnectionTest, new TaggingEncrypter(0x02)); connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); - EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, _)).Times(0); + EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, _, _)).Times(0); EXPECT_CALL(*send_algorithm_, AbandoningPacket(sequence_number, _)).Times(1); - const QuicTime::Delta kDefaultRetransmissionTime = - QuicTime::Delta::FromMilliseconds(500); QuicTime default_retransmission_time = clock_.ApproximateNow().Add( - kDefaultRetransmissionTime); + DefaultRetransmissionTime()); EXPECT_EQ(default_retransmission_time, connection_.GetRetransmissionAlarm()->deadline()); // Simulate the retransimission alarm firing - clock_.AdvanceTime(kDefaultRetransmissionTime); + clock_.AdvanceTime(DefaultRetransmissionTime()); connection_.OnRetransmissionTimeout(); } @@ -1552,13 +1846,14 @@ TEST_F(QuicConnectionTest, RetransmitPacketsWithInitialEncryption) { SendStreamDataToPeer(2, "bar", 0, !kFin, NULL); - EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, _)).Times(1); + EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, _, _)).Times(1); EXPECT_CALL(*send_algorithm_, AbandoningPacket(_, _)).Times(1); connection_.RetransmitUnackedPackets(QuicConnection::INITIAL_ENCRYPTION_ONLY); } TEST_F(QuicConnectionTest, BufferNonDecryptablePackets) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); use_tagging_decrypter(); const uint8 tag = 0x07; @@ -1574,26 +1869,26 @@ TEST_F(QuicConnectionTest, BufferNonDecryptablePackets) { connection_.SetDecrypter(new StrictTaggingDecrypter(tag)); connection_.SetDefaultEncryptionLevel(ENCRYPTION_INITIAL); connection_.SetEncrypter(ENCRYPTION_INITIAL, new TaggingEncrypter(tag)); - EXPECT_CALL(visitor_, OnPacket(_, _, _, _)).Times(2).WillRepeatedly( + EXPECT_CALL(visitor_, OnStreamFrames(_)).Times(2).WillRepeatedly( Return(true)); ProcessDataPacketAtLevel(2, false, kEntropyFlag, ENCRYPTION_INITIAL); // Finally, process a third packet and note that we do not // reprocess the buffered packet. - EXPECT_CALL(visitor_, OnPacket(_, _, _, _)).WillOnce(Return(true)); + EXPECT_CALL(visitor_, OnStreamFrames(_)).WillOnce(Return(true)); ProcessDataPacketAtLevel(3, false, kEntropyFlag, ENCRYPTION_INITIAL); } TEST_F(QuicConnectionTest, TestRetransmitOrder) { QuicByteCount first_packet_size; - EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, _)).WillOnce( - SaveArg<2>(&first_packet_size)); + EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, _, _)).WillOnce( + DoAll(SaveArg<2>(&first_packet_size), Return(true))); EXPECT_CALL(*send_algorithm_, AbandoningPacket(_, _)).Times(2); connection_.SendStreamData(1, "first_packet", 0, !kFin); QuicByteCount second_packet_size; - EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, _)).WillOnce( - SaveArg<2>(&second_packet_size)); + EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, _, _)).WillOnce( + DoAll(SaveArg<2>(&second_packet_size), Return(true))); connection_.SendStreamData(1, "second_packet", 12, !kFin); EXPECT_NE(first_packet_size, second_packet_size); // Advance the clock by huge time to make sure packets will be retransmitted. @@ -1601,20 +1896,20 @@ TEST_F(QuicConnectionTest, TestRetransmitOrder) { { InSequence s; EXPECT_CALL(*send_algorithm_, - SentPacket(_, _, first_packet_size, _)); + SentPacket(_, _, first_packet_size, _, _)); EXPECT_CALL(*send_algorithm_, - SentPacket(_, _, second_packet_size, _)); + SentPacket(_, _, second_packet_size, _, _)); } connection_.OnRetransmissionTimeout(); } TEST_F(QuicConnectionTest, TestRetransmissionCountCalculation) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); EXPECT_CALL(*send_algorithm_, OnIncomingLoss(_)).Times(1); EXPECT_CALL(*send_algorithm_, AbandoningPacket(_, _)).Times(2); QuicPacketSequenceNumber original_sequence_number; - EXPECT_CALL(*send_algorithm_, - SentPacket(_, _, _, NOT_RETRANSMISSION)) - .WillOnce(SaveArg<1>(&original_sequence_number)); + EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, NOT_RETRANSMISSION, _)) + .WillOnce(DoAll(SaveArg<1>(&original_sequence_number), Return(true))); connection_.SendStreamData(1, "foo", 0, !kFin); EXPECT_TRUE(QuicConnectionPeer::IsSavedForRetransmission( &connection_, original_sequence_number)); @@ -1623,9 +1918,8 @@ TEST_F(QuicConnectionTest, TestRetransmissionCountCalculation) { // Force retransmission due to RTO. clock_.AdvanceTime(QuicTime::Delta::FromSeconds(10)); QuicPacketSequenceNumber rto_sequence_number; - EXPECT_CALL(*send_algorithm_, - SentPacket(_, _, _, IS_RETRANSMISSION)) - .WillOnce(SaveArg<1>(&rto_sequence_number)); + EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, IS_RETRANSMISSION, _)) + .WillOnce(DoAll(SaveArg<1>(&rto_sequence_number), Return(true))); connection_.OnRetransmissionTimeout(); EXPECT_FALSE(QuicConnectionPeer::IsSavedForRetransmission( &connection_, original_sequence_number)); @@ -1637,10 +1931,10 @@ TEST_F(QuicConnectionTest, TestRetransmissionCountCalculation) { QuicPacketSequenceNumber nack_sequence_number; // Ack packets might generate some other packets, which are not // retransmissions. (More ack packets). - EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, NOT_RETRANSMISSION)) + EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, NOT_RETRANSMISSION, _)) .Times(AnyNumber()); - EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, IS_RETRANSMISSION)) - .WillOnce(SaveArg<1>(&nack_sequence_number)); + EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, IS_RETRANSMISSION, _)) + .WillOnce(DoAll(SaveArg<1>(&nack_sequence_number), Return(true))); QuicAckFrame ack(rto_sequence_number, QuicTime::Zero(), 0); // Ack the retransmitted packet. ack.received_info.missing_packets.insert(rto_sequence_number); @@ -1670,6 +1964,45 @@ TEST_F(QuicConnectionTest, SetRTOAfterWritingToSocket) { EXPECT_EQ(1u, QuicConnectionPeer::GetNumRetransmissionTimeouts(&connection_)); } +TEST_F(QuicConnectionTest, DelayRTOWithAckReceipt) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, NOT_RETRANSMISSION, _)) + .Times(2); + connection_.SendStreamData(1, "foo", 0, !kFin); + connection_.SendStreamData(2, "bar", 0, !kFin); + EXPECT_EQ(2u, QuicConnectionPeer::GetNumRetransmissionTimeouts(&connection_)); + + // Advance the time right before the RTO, then receive an ack for the first + // packet to delay the RTO. + clock_.AdvanceTime(DefaultRetransmissionTime()); + EXPECT_EQ(2u, QuicConnectionPeer::GetNumRetransmissionTimeouts(&connection_)); + EXPECT_CALL(*send_algorithm_, OnIncomingAck(_, _, _)).Times(1); + QuicAckFrame ack(1, QuicTime::Zero(), 0); + ProcessAckPacket(&ack, true); + EXPECT_EQ(1u, QuicConnectionPeer::GetNumRetransmissionTimeouts(&connection_)); + + // Move forward past the original RTO and ensure the RTO is still pending. + clock_.AdvanceTime(DefaultRetransmissionTime()); + EXPECT_EQ(1u, QuicConnectionPeer::GetNumRetransmissionTimeouts(&connection_)); + + // Ensure the second packet gets retransmitted when it finally fires. + EXPECT_TRUE( + QuicConnectionPeer::GetRetransmissionAlarm(&connection_)->IsSet()); + EXPECT_GE( + QuicConnectionPeer::GetRetransmissionAlarm(&connection_)->deadline(), + clock_.ApproximateNow()); + clock_.AdvanceTime(DefaultRetransmissionTime()); + EXPECT_LT( + QuicConnectionPeer::GetRetransmissionAlarm(&connection_)->deadline(), + clock_.ApproximateNow()); + EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, IS_RETRANSMISSION, _)); + EXPECT_CALL(*send_algorithm_, AbandoningPacket(_, _)); + connection_.OnRetransmissionTimeout(); + + // The new retransmitted sequence number should now be in the timeout queue. + EXPECT_EQ(1u, QuicConnectionPeer::GetNumRetransmissionTimeouts(&connection_)); +} + TEST_F(QuicConnectionTest, TestQueued) { EXPECT_EQ(0u, connection_.NumQueuedPackets()); helper_->set_blocked(true); @@ -1689,6 +2022,7 @@ TEST_F(QuicConnectionTest, TestQueued) { } TEST_F(QuicConnectionTest, CloseFecGroup) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); // Don't send missing packet 1 // Don't send missing packet 2 ProcessFecProtectedPacket(3, false, !kEntropyFlag); @@ -1721,21 +2055,23 @@ TEST_F(QuicConnectionTest, WithQuicCongestionFeedbackFrame) { TEST_F(QuicConnectionTest, UpdateQuicCongestionFeedbackFrame) { SendAckPacketToPeer(); EXPECT_CALL(*receive_algorithm_, RecordIncomingPacket(_, _, _, _)); + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); ProcessPacket(1); } TEST_F(QuicConnectionTest, DontUpdateQuicCongestionFeedbackFrameForRevived) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); SendAckPacketToPeer(); // Process an FEC packet, and revive the missing data packet // but only contact the receive_algorithm once. EXPECT_CALL(*receive_algorithm_, RecordIncomingPacket(_, _, _, _)); - ProcessFecPacket(2, 1, true, !kEntropyFlag); + ProcessFecPacket(2, 1, true, !kEntropyFlag, NULL); } TEST_F(QuicConnectionTest, InitialTimeout) { EXPECT_TRUE(connection_.connected()); EXPECT_CALL(visitor_, ConnectionClose(QUIC_CONNECTION_TIMED_OUT, false)); - EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, _)); + EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, _, _)); QuicTime default_timeout = clock_.ApproximateNow().Add( QuicTime::Delta::FromSeconds(kDefaultInitialTimeoutSecs)); @@ -1774,7 +2110,7 @@ TEST_F(QuicConnectionTest, TimeoutAfterSend) { // This time, we should time out. EXPECT_CALL(visitor_, ConnectionClose(QUIC_CONNECTION_TIMED_OUT, false)); - EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, _)); + EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, _, _)); clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); EXPECT_EQ(default_timeout.Add(QuicTime::Delta::FromMilliseconds(5)), clock_.ApproximateNow()); @@ -1789,7 +2125,7 @@ TEST_F(QuicConnectionTest, SendScheduler) { EXPECT_CALL(*send_algorithm_, TimeUntilSend(_, NOT_RETRANSMISSION, _, _)).WillOnce( testing::Return(QuicTime::Delta::Zero())); - EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, _)); + EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, _, _)); connection_.SendOrQueuePacket( ENCRYPTION_NONE, 1, packet, kTestEntropyHash, HAS_RETRANSMITTABLE_DATA); EXPECT_EQ(0u, connection_.NumQueuedPackets()); @@ -1801,7 +2137,7 @@ TEST_F(QuicConnectionTest, SendSchedulerDelay) { EXPECT_CALL(*send_algorithm_, TimeUntilSend(_, NOT_RETRANSMISSION, _, _)).WillOnce( testing::Return(QuicTime::Delta::FromMicroseconds(1))); - EXPECT_CALL(*send_algorithm_, SentPacket(_, 1, _, _)).Times(0); + EXPECT_CALL(*send_algorithm_, SentPacket(_, 1, _, _, _)).Times(0); connection_.SendOrQueuePacket( ENCRYPTION_NONE, 1, packet, kTestEntropyHash, HAS_RETRANSMITTABLE_DATA); EXPECT_EQ(1u, connection_.NumQueuedPackets()); @@ -1812,7 +2148,7 @@ TEST_F(QuicConnectionTest, SendSchedulerForce) { QuicPacket* packet = ConstructDataPacket(1, 0, !kEntropyFlag); EXPECT_CALL(*send_algorithm_, TimeUntilSend(_, IS_RETRANSMISSION, _, _)).Times(0); - EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, _)); + EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, _, _)); connection_.SendOrQueuePacket( ENCRYPTION_NONE, 1, packet, kTestEntropyHash, HAS_RETRANSMITTABLE_DATA); // XXX: fixme. was: connection_.SendOrQueuePacket(1, packet, kForce); @@ -1825,7 +2161,7 @@ TEST_F(QuicConnectionTest, SendSchedulerEAGAIN) { EXPECT_CALL(*send_algorithm_, TimeUntilSend(_, NOT_RETRANSMISSION, _, _)).WillOnce( testing::Return(QuicTime::Delta::Zero())); - EXPECT_CALL(*send_algorithm_, SentPacket(_, 1, _, _)).Times(0); + EXPECT_CALL(*send_algorithm_, SentPacket(_, 1, _, _, _)).Times(0); connection_.SendOrQueuePacket( ENCRYPTION_NONE, 1, packet, kTestEntropyHash, HAS_RETRANSMITTABLE_DATA); EXPECT_EQ(1u, connection_.NumQueuedPackets()); @@ -1848,7 +2184,7 @@ TEST_F(QuicConnectionTest, SendSchedulerDelayThenSend) { testing::Return(QuicTime::Delta::Zero())); clock_.AdvanceTime(QuicTime::Delta::FromMicroseconds(1)); connection_.GetSendAlarm()->Cancel(); - EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, _)); + EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, _, _)); EXPECT_CALL(visitor_, OnCanWrite()); connection_.OnCanWrite(); EXPECT_EQ(0u, connection_.NumQueuedPackets()); @@ -1859,7 +2195,7 @@ TEST_F(QuicConnectionTest, SendSchedulerDelayThenRetransmit) { .WillRepeatedly(testing::Return(QuicTime::Delta::Zero())); EXPECT_CALL(*send_algorithm_, AbandoningPacket(1, _)).Times(1); EXPECT_CALL(*send_algorithm_, - SentPacket(_, 1, _, NOT_RETRANSMISSION)); + SentPacket(_, 1, _, NOT_RETRANSMISSION, _)); connection_.SendStreamData(1, "foo", 0, !kFin); EXPECT_EQ(0u, connection_.NumQueuedPackets()); // Advance the time for retransmission of lost packet. @@ -1879,7 +2215,7 @@ TEST_F(QuicConnectionTest, SendSchedulerDelayThenRetransmit) { // Ensure the scheduler is notified this is a retransmit. EXPECT_CALL(*send_algorithm_, - SentPacket(_, _, _, IS_RETRANSMISSION)); + SentPacket(_, _, _, IS_RETRANSMISSION, _)); clock_.AdvanceTime(QuicTime::Delta::FromMicroseconds(1)); connection_.GetSendAlarm()->Cancel(); EXPECT_CALL(visitor_, OnCanWrite()); @@ -1904,6 +2240,7 @@ TEST_F(QuicConnectionTest, SendSchedulerDelayAndQueue) { } TEST_F(QuicConnectionTest, SendSchedulerDelayThenAckAndSend) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); QuicPacket* packet = ConstructDataPacket(1, 0, !kEntropyFlag); EXPECT_CALL(*send_algorithm_, TimeUntilSend(_, NOT_RETRANSMISSION, _, _)).WillOnce( @@ -1919,7 +2256,7 @@ TEST_F(QuicConnectionTest, SendSchedulerDelayThenAckAndSend) { TimeUntilSend(_, NOT_RETRANSMISSION, _, _)).WillRepeatedly( testing::Return(QuicTime::Delta::Zero())); EXPECT_CALL(*send_algorithm_, - SentPacket(_, _, _, _)); + SentPacket(_, _, _, _, _)); ProcessAckPacket(&frame, true); EXPECT_EQ(0u, connection_.NumQueuedPackets()); @@ -1928,6 +2265,7 @@ TEST_F(QuicConnectionTest, SendSchedulerDelayThenAckAndSend) { } TEST_F(QuicConnectionTest, SendSchedulerDelayThenAckAndHold) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); QuicPacket* packet = ConstructDataPacket(1, 0, !kEntropyFlag); EXPECT_CALL(*send_algorithm_, TimeUntilSend(_, NOT_RETRANSMISSION, _, _)).WillOnce( @@ -1966,8 +2304,9 @@ TEST_F(QuicConnectionTest, TestQueueLimitsOnSendStreamData) { // All packets carry version info till version is negotiated. size_t payload_length; connection_.options()->max_packet_length = - GetPacketLengthForOneStream(connection_.version(), kIncludeVersion, - NOT_IN_FEC_GROUP, &payload_length); + GetPacketLengthForOneStream( + connection_.version(), kIncludeVersion, PACKET_1BYTE_SEQUENCE_NUMBER, + NOT_IN_FEC_GROUP, &payload_length); // Queue the first packet. EXPECT_CALL(*send_algorithm_, @@ -1983,11 +2322,12 @@ TEST_F(QuicConnectionTest, LoopThroughSendingPackets) { // All packets carry version info till version is negotiated. size_t payload_length; connection_.options()->max_packet_length = - GetPacketLengthForOneStream(connection_.version(), kIncludeVersion, - NOT_IN_FEC_GROUP, &payload_length); + GetPacketLengthForOneStream( + connection_.version(), kIncludeVersion, PACKET_1BYTE_SEQUENCE_NUMBER, + NOT_IN_FEC_GROUP, &payload_length); // Queue the first packet. - EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, _)).Times(7); + EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, _, _)).Times(7); // The first stream frame will consume 2 fewer bytes than the other six. const string payload(payload_length * 7 - 12, 'a'); EXPECT_EQ(payload.size(), @@ -1995,10 +2335,11 @@ TEST_F(QuicConnectionTest, LoopThroughSendingPackets) { } TEST_F(QuicConnectionTest, NoAckForClose) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); ProcessPacket(1); EXPECT_CALL(*send_algorithm_, OnIncomingAck(_, _, _)).Times(0); EXPECT_CALL(visitor_, ConnectionClose(QUIC_PEER_GOING_AWAY, true)); - EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, _)).Times(0); + EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, _, _)).Times(0); ProcessClosePacket(2, 0); } @@ -2008,7 +2349,7 @@ TEST_F(QuicConnectionTest, SendWhenDisconnected) { connection_.CloseConnection(QUIC_PEER_GOING_AWAY, false); EXPECT_FALSE(connection_.connected()); QuicPacket* packet = ConstructDataPacket(1, 0, !kEntropyFlag); - EXPECT_CALL(*send_algorithm_, SentPacket(_, 1, _, _)).Times(0); + EXPECT_CALL(*send_algorithm_, SentPacket(_, 1, _, _, _)).Times(0); connection_.SendOrQueuePacket( ENCRYPTION_NONE, 1, packet, kTestEntropyHash, HAS_RETRANSMITTABLE_DATA); } @@ -2026,6 +2367,8 @@ TEST_F(QuicConnectionTest, PublicReset) { } TEST_F(QuicConnectionTest, GoAway) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + QuicGoAwayFrame goaway; goaway.last_good_stream_id = 1; goaway.error_code = QUIC_PEER_GOING_AWAY; @@ -2038,12 +2381,14 @@ TEST_F(QuicConnectionTest, MissingPacketsBeforeLeastUnacked) { QuicAckFrame ack(0, QuicTime::Zero(), 4); // Set the sequence number of the ack packet to be least unacked (4) creator_.set_sequence_number(3); + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); ProcessAckPacket(&ack, true); EXPECT_TRUE(outgoing_ack()->received_info.missing_packets.empty()); } TEST_F(QuicConnectionTest, ReceivedEntropyHashCalculation) { - EXPECT_CALL(visitor_, OnPacket(_, _, _, _)).WillRepeatedly(Return(true)); + EXPECT_CALL(visitor_, OnStreamFrames(_)).WillRepeatedly(Return(true)); + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); ProcessDataPacket(1, 1, kEntropyFlag); ProcessDataPacket(4, 1, kEntropyFlag); ProcessDataPacket(3, 1, !kEntropyFlag); @@ -2052,7 +2397,8 @@ TEST_F(QuicConnectionTest, ReceivedEntropyHashCalculation) { } TEST_F(QuicConnectionTest, UpdateEntropyForReceivedPackets) { - EXPECT_CALL(visitor_, OnPacket(_, _, _, _)).WillRepeatedly(Return(true)); + EXPECT_CALL(visitor_, OnStreamFrames(_)).WillRepeatedly(Return(true)); + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); ProcessDataPacket(1, 1, kEntropyFlag); ProcessDataPacket(5, 1, kEntropyFlag); ProcessDataPacket(4, 1, !kEntropyFlag); @@ -2065,14 +2411,15 @@ TEST_F(QuicConnectionTest, UpdateEntropyForReceivedPackets) { QuicPacketEntropyHash six_packet_entropy_hash = 0; if (ProcessAckPacket(&ack, true)) { six_packet_entropy_hash = 1 << 6; - }; + } EXPECT_EQ((kRandomEntropyHash + (1 << 5) + six_packet_entropy_hash), outgoing_ack()->received_info.entropy_hash); } TEST_F(QuicConnectionTest, UpdateEntropyHashUptoCurrentPacket) { - EXPECT_CALL(visitor_, OnPacket(_, _, _, _)).WillRepeatedly(Return(true)); + EXPECT_CALL(visitor_, OnStreamFrames(_)).WillRepeatedly(Return(true)); + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); ProcessDataPacket(1, 1, kEntropyFlag); ProcessDataPacket(5, 1, !kEntropyFlag); ProcessDataPacket(22, 1, kEntropyFlag); @@ -2091,7 +2438,8 @@ TEST_F(QuicConnectionTest, UpdateEntropyHashUptoCurrentPacket) { } TEST_F(QuicConnectionTest, EntropyCalculationForTruncatedAck) { - EXPECT_CALL(visitor_, OnPacket(_, _, _, _)).WillRepeatedly(Return(true)); + EXPECT_CALL(visitor_, OnStreamFrames(_)).WillRepeatedly(Return(true)); + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); QuicPacketEntropyHash entropy[51]; entropy[0] = 0; for (int i = 1; i < 51; ++i) { @@ -2215,7 +2563,8 @@ TEST_F(QuicConnectionTest, ClientHandlesVersionNegotiation) { scoped_ptr<QuicPacket> packet( framer_.BuildUnsizedDataPacket(header, frames).packet); encrypted.reset(framer_.EncryptPacket(ENCRYPTION_NONE, 12, *packet)); - EXPECT_CALL(visitor_, OnPacket(_, _, _, _)).Times(1); + EXPECT_CALL(visitor_, OnStreamFrames(_)).Times(1); + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); connection_.ProcessUdpPacket(IPEndPoint(), IPEndPoint(), *encrypted); ASSERT_FALSE(QuicPacketCreatorPeer::SendVersionInPacket( @@ -2250,18 +2599,18 @@ TEST_F(QuicConnectionTest, BadVersionNegotiation) { TEST_F(QuicConnectionTest, CheckSendStats) { EXPECT_CALL(*send_algorithm_, AbandoningPacket(_, _)).Times(3); EXPECT_CALL(*send_algorithm_, - SentPacket(_, _, _, NOT_RETRANSMISSION)); + SentPacket(_, _, _, NOT_RETRANSMISSION, _)); connection_.SendStreamData(1u, "first", 0, !kFin); size_t first_packet_size = last_sent_packet_size(); EXPECT_CALL(*send_algorithm_, - SentPacket(_, _, _, NOT_RETRANSMISSION)); + SentPacket(_, _, _, NOT_RETRANSMISSION, _)); connection_.SendStreamData(1u, "second", 0, !kFin); size_t second_packet_size = last_sent_packet_size(); // 2 retransmissions due to rto, 1 due to explicit nack. EXPECT_CALL(*send_algorithm_, - SentPacket(_, _, _, IS_RETRANSMISSION)).Times(3); + SentPacket(_, _, _, IS_RETRANSMISSION, _)).Times(3); // Retransmit due to RTO. clock_.AdvanceTime(QuicTime::Delta::FromSeconds(10)); @@ -2275,10 +2624,10 @@ TEST_F(QuicConnectionTest, CheckSendStats) { QuicConnectionPeer::GetSentEntropyHash(&connection_, 3) ^ QuicConnectionPeer::GetSentEntropyHash(&connection_, 2); QuicFrame frame(&nack_three); - EXPECT_CALL(visitor_, OnAck(_)); EXPECT_CALL(*send_algorithm_, OnIncomingAck(_, _, _)).Times(1); EXPECT_CALL(*send_algorithm_, OnIncomingLoss(_)).Times(1); EXPECT_CALL(visitor_, OnCanWrite()).Times(3).WillRepeatedly(Return(true)); + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); ProcessFramePacket(frame); ProcessFramePacket(frame); @@ -2300,12 +2649,14 @@ TEST_F(QuicConnectionTest, CheckSendStats) { } TEST_F(QuicConnectionTest, CheckReceiveStats) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + size_t received_bytes = 0; received_bytes += ProcessFecProtectedPacket(1, false, !kEntropyFlag); received_bytes += ProcessFecProtectedPacket(3, false, !kEntropyFlag); // Should be counted against dropped packets. received_bytes += ProcessDataPacket(3, 1, !kEntropyFlag); - received_bytes += ProcessFecPacket(4, 1, true, !kEntropyFlag); // Fec packet + received_bytes += ProcessFecPacket(4, 1, true, !kEntropyFlag, NULL); EXPECT_CALL(*send_algorithm_, SmoothedRtt()).WillOnce( Return(QuicTime::Delta::Zero())); @@ -2368,7 +2719,8 @@ TEST_F(QuicConnectionTest, DontProcessFramesIfPacketClosedConnection) { ENCRYPTION_NONE, 1, *packet)); EXPECT_CALL(visitor_, ConnectionClose(QUIC_PEER_GOING_AWAY, true)); - EXPECT_CALL(visitor_, OnPacket(_, _, _, _)).Times(0); + EXPECT_CALL(visitor_, OnStreamFrames(_)).Times(0); + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); connection_.ProcessUdpPacket(IPEndPoint(), IPEndPoint(), *encrypted); } @@ -2440,6 +2792,191 @@ TEST_F(QuicConnectionTest, ConnectionCloseWhenNothingPending) { EXPECT_EQ(1u, helper_->packets_write_attempts()); } +TEST_F(QuicConnectionTest, AckNotifierTriggerCallback) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + + // Create a delegate which we expect to be called. + MockAckNotifierDelegate delegate; + EXPECT_CALL(delegate, OnAckNotification()).Times(1);; + + // Send some data, which will register the delegate to be notified. + connection_.SendStreamDataAndNotifyWhenAcked(1, "foo", 0, !kFin, &delegate); + + // Process an ACK from the server which should trigger the callback. + EXPECT_CALL(*send_algorithm_, OnIncomingAck(_, _, _)).Times(1); + QuicAckFrame frame(1, QuicTime::Zero(), 0); + ProcessAckPacket(&frame, true); +} + +TEST_F(QuicConnectionTest, AckNotifierFailToTriggerCallback) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + + // Create a delegate which we don't expect to be called. + MockAckNotifierDelegate delegate; + EXPECT_CALL(delegate, OnAckNotification()).Times(0);; + + EXPECT_CALL(*send_algorithm_, OnIncomingAck(_, _, _)).Times(2); + EXPECT_CALL(*send_algorithm_, OnIncomingLoss(_)).Times(1); + + // Send some data, which will register the delegate to be notified. This will + // not be ACKed and so the delegate should never be called. + connection_.SendStreamDataAndNotifyWhenAcked(1, "foo", 0, !kFin, &delegate); + + // Send some other data which we will ACK. + connection_.SendStreamData(1, "foo", 0, !kFin); + connection_.SendStreamData(1, "bar", 0, !kFin); + + // Now we receive ACK for packets 2 and 3, but importantly missing packet 1 + // which we registered to be notified about. + QuicAckFrame frame(3, QuicTime::Zero(), 0); + frame.received_info.missing_packets.insert(1); + ProcessAckPacket(&frame, true); +} + +TEST_F(QuicConnectionTest, AckNotifierCallbackAfterRetransmission) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + + // Create a delegate which we expect to be called. + MockAckNotifierDelegate delegate; + EXPECT_CALL(delegate, OnAckNotification()).Times(1);; + + // In total expect ACKs for all 4 packets. + EXPECT_CALL(*send_algorithm_, OnIncomingAck(_, _, _)).Times(4); + + // We will lose the second packet. + EXPECT_CALL(*send_algorithm_, OnIncomingLoss(_)).Times(1); + + // Send four packets, and register to be notified on ACK of packet 2. + connection_.SendStreamData(1, "foo", 0, !kFin); + connection_.SendStreamDataAndNotifyWhenAcked(1, "bar", 0, !kFin, &delegate); + connection_.SendStreamData(1, "baz", 0, !kFin); + connection_.SendStreamData(1, "qux", 0, !kFin); + + // Now we receive ACK for packets 1, 3, and 4. + QuicAckFrame frame(4, QuicTime::Zero(), 0); + frame.received_info.missing_packets.insert(2); + ProcessAckPacket(&frame, true); + + // Advance time to trigger RTO, for packet 2 (which should be retransmitted as + // packet 5). + EXPECT_CALL(*send_algorithm_, AbandoningPacket(2, _)).Times(1); + EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, _, _)).Times(1); + + clock_.AdvanceTime(DefaultRetransmissionTime()); + connection_.OnRetransmissionTimeout(); + + // Now we get an ACK for packet 5 (retransmitted packet 2), which should + // trigger the callback. + QuicAckFrame second_ack_frame(5, QuicTime::Zero(), 0); + ProcessAckPacket(&second_ack_frame, true); +} + +// TODO(rjshade): Add a similar test that FEC recovery on peer (and resulting +// ACK) triggers notification on our end. +TEST_F(QuicConnectionTest, AckNotifierCallbackAfterFECRecovery) { + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); + EXPECT_CALL(visitor_, OnCanWrite()).Times(1).WillOnce(Return(true)); + + // Create a delegate which we expect to be called. + MockAckNotifierDelegate delegate; + EXPECT_CALL(delegate, OnAckNotification()).Times(1);; + + // Expect ACKs for 1 packet. + EXPECT_CALL(*send_algorithm_, OnIncomingAck(_, _, _)).Times(1); + + // Send one packet, and register to be notified on ACK. + connection_.SendStreamDataAndNotifyWhenAcked(1, "foo", 0, !kFin, &delegate); + + // Ack packet gets dropped, but we receive an FEC packet that covers it. + // Should recover the Ack packet and trigger the notification callback. + QuicFrames frames; + + QuicAckFrame ack_frame(1, QuicTime::Zero(), 0); + frames.push_back(QuicFrame(&ack_frame)); + + // Dummy stream frame to satisfy expectations set elsewhere. + frames.push_back(QuicFrame(&frame1_)); + + QuicPacketHeader ack_header; + ack_header.public_header.guid = guid_; + ack_header.public_header.reset_flag = false; + ack_header.public_header.version_flag = false; + ack_header.entropy_flag = !kEntropyFlag; + ack_header.fec_flag = true; + ack_header.packet_sequence_number = 42; + ack_header.is_in_fec_group = IN_FEC_GROUP; + ack_header.fec_group = 1; + + QuicPacket* packet = + framer_.BuildUnsizedDataPacket(ack_header, frames).packet; + + // Take the packet which contains the ACK frame, and construct and deliver an + // FEC packet which allows the ACK packet to be recovered. + ProcessFecPacket(2, 1, true, !kEntropyFlag, packet); +} + +class MockQuicConnectionDebugVisitor + : public QuicConnectionDebugVisitorInterface { + public: + MOCK_METHOD1(OnFrameAddedToPacket, + void(const QuicFrame&)); + + MOCK_METHOD4(OnPacketSent, + void(QuicPacketSequenceNumber, + EncryptionLevel, + const QuicEncryptedPacket&, + int)); + + MOCK_METHOD2(OnPacketRetransmitted, + void(QuicPacketSequenceNumber, + QuicPacketSequenceNumber)); + + MOCK_METHOD3(OnPacketReceived, + void(const IPEndPoint&, + const IPEndPoint&, + const QuicEncryptedPacket&)); + + MOCK_METHOD1(OnProtocolVersionMismatch, + void(QuicVersion)); + + MOCK_METHOD1(OnPacketHeader, + void(const QuicPacketHeader& header)); + + MOCK_METHOD1(OnStreamFrame, + void(const QuicStreamFrame&)); + + MOCK_METHOD1(OnAckFrame, + void(const QuicAckFrame& frame)); + + MOCK_METHOD1(OnCongestionFeedbackFrame, + void(const QuicCongestionFeedbackFrame&)); + + MOCK_METHOD1(OnRstStreamFrame, + void(const QuicRstStreamFrame&)); + + MOCK_METHOD1(OnConnectionCloseFrame, + void(const QuicConnectionCloseFrame&)); + + MOCK_METHOD1(OnPublicResetPacket, + void(const QuicPublicResetPacket&)); + + MOCK_METHOD1(OnVersionNegotiationPacket, + void(const QuicVersionNegotiationPacket&)); + + MOCK_METHOD2(OnRevivedPacket, + void(const QuicPacketHeader&, StringPiece payload)); +}; + +TEST_F(QuicConnectionTest, OnPacketHeaderDebugVisitor) { + QuicPacketHeader header; + + scoped_ptr<MockQuicConnectionDebugVisitor> + debug_visitor(new StrictMock<MockQuicConnectionDebugVisitor>); + connection_.set_debug_visitor(debug_visitor.get()); + EXPECT_CALL(*debug_visitor, OnPacketHeader(Ref(header))).Times(1); + connection_.OnPacketHeader(header); +} + } // namespace } // namespace test } // namespace net diff --git a/chromium/net/quic/quic_crypto_client_stream.cc b/chromium/net/quic/quic_crypto_client_stream.cc index ab1d974b32b..b0f81174b71 100644 --- a/chromium/net/quic/quic_crypto_client_stream.cc +++ b/chromium/net/quic/quic_crypto_client_stream.cc @@ -85,6 +85,8 @@ QuicCryptoClientStream::~QuicCryptoClientStream() { void QuicCryptoClientStream::OnHandshakeMessage( const CryptoHandshakeMessage& message) { + QuicCryptoStream::OnHandshakeMessage(message); + DoHandshakeLoop(&message); } @@ -150,7 +152,7 @@ void QuicCryptoClientStream::DoHandshakeLoop( crypto_config_->LookupOrCreate(server_hostname_); if (in != NULL) { - DVLOG(1) << "Client received: " << in->DebugString(); + DVLOG(1) << "Client: Received " << in->DebugString(); } for (;;) { @@ -170,7 +172,7 @@ void QuicCryptoClientStream::DoHandshakeLoop( crypto_config_->FillInchoateClientHello( server_hostname_, cached, &crypto_negotiated_params_, &out); next_state_ = STATE_RECV_REJ; - DVLOG(1) << "Client Sending: " << out.DebugString(); + DVLOG(1) << "Client: Sending " << out.DebugString(); SendHandshakeMessage(out); return; } @@ -198,7 +200,7 @@ void QuicCryptoClientStream::DoHandshakeLoop( cert_verify_result_.reset(); } next_state_ = STATE_RECV_SHLO; - DVLOG(1) << "Client Sending: " << out.DebugString(); + DVLOG(1) << "Client: Sending " << out.DebugString(); SendHandshakeMessage(out); // Be prepared to decrypt with the new server write key. session()->connection()->SetAlternativeDecrypter( @@ -232,7 +234,7 @@ void QuicCryptoClientStream::DoHandshakeLoop( return; } error = crypto_config_->ProcessRejection( - cached, *in, session()->connection()->clock()->WallNow(), + *in, session()->connection()->clock()->WallNow(), cached, &crypto_negotiated_params_, &error_details); if (error != QUIC_NO_ERROR) { CloseConnectionWithDetails(error, error_details); @@ -262,7 +264,6 @@ void QuicCryptoClientStream::DoHandshakeLoop( verify_ok_ = false; ProofVerifier::Status status = verifier->VerifyProof( - session()->connection()->version(), server_hostname_, cached->server_config(), cached->certs(), @@ -332,8 +333,8 @@ void QuicCryptoClientStream::DoHandshakeLoop( return; } error = crypto_config_->ProcessServerHello( - *in, session()->connection()->guid(), &crypto_negotiated_params_, - &error_details); + *in, session()->connection()->guid(), cached, + &crypto_negotiated_params_, &error_details); if (error != QUIC_NO_ERROR) { CloseConnectionWithDetails( error, "Server hello invalid: " + error_details); diff --git a/chromium/net/quic/quic_crypto_client_stream_test.cc b/chromium/net/quic/quic_crypto_client_stream_test.cc index 9f9e7f75c57..2ad9a3a21c1 100644 --- a/chromium/net/quic/quic_crypto_client_stream_test.cc +++ b/chromium/net/quic/quic_crypto_client_stream_test.cc @@ -21,31 +21,6 @@ namespace { const char kServerHostname[] = "example.com"; -class TestQuicVisitor : public NoOpFramerVisitor { - public: - TestQuicVisitor() - : frame_valid_(false) { - } - - // NoOpFramerVisitor - virtual bool OnStreamFrame(const QuicStreamFrame& frame) OVERRIDE { - frame_ = frame; - frame_valid_ = true; - return true; - } - - bool frame_valid() const { - return frame_valid_; - } - QuicStreamFrame* frame() { return &frame_; } - - private: - QuicStreamFrame frame_; - bool frame_valid_; - - DISALLOW_COPY_AND_ASSIGN(TestQuicVisitor); -}; - class QuicCryptoClientStreamTest : public ::testing::Test { public: QuicCryptoClientStreamTest() diff --git a/chromium/net/quic/quic_crypto_server_stream.cc b/chromium/net/quic/quic_crypto_server_stream.cc index f7b67b7c23e..63637f54294 100644 --- a/chromium/net/quic/quic_crypto_server_stream.cc +++ b/chromium/net/quic/quic_crypto_server_stream.cc @@ -27,6 +27,8 @@ QuicCryptoServerStream::~QuicCryptoServerStream() { void QuicCryptoServerStream::OnHandshakeMessage( const CryptoHandshakeMessage& message) { + QuicCryptoStream::OnHandshakeMessage(message); + // Do not process handshake messages after the handshake is confirmed. if (handshake_confirmed_) { CloseConnection(QUIC_CRYPTO_MESSAGE_AFTER_HANDSHAKE_COMPLETE); @@ -129,7 +131,6 @@ QuicErrorCode QuicCryptoServerStream::ProcessClientHello( string* error_details) { return crypto_config_.ProcessClientHello( message, - session()->connection()->version(), session()->connection()->guid(), session()->connection()->peer_address(), session()->connection()->clock(), diff --git a/chromium/net/quic/quic_crypto_server_stream.h b/chromium/net/quic/quic_crypto_server_stream.h index f1e30cb558a..b4967d866ce 100644 --- a/chromium/net/quic/quic_crypto_server_stream.h +++ b/chromium/net/quic/quic_crypto_server_stream.h @@ -43,8 +43,6 @@ class NET_EXPORT_PRIVATE QuicCryptoServerStream : public QuicCryptoStream { CryptoHandshakeMessage* reply, std::string* error_details); - const QuicCryptoServerConfig* crypto_config() { return &crypto_config_; } - private: friend class test::CryptoTestUtils; diff --git a/chromium/net/quic/quic_crypto_server_stream_test.cc b/chromium/net/quic/quic_crypto_server_stream_test.cc index 3bb2593f1b2..9e92b2b0854 100644 --- a/chromium/net/quic/quic_crypto_server_stream_test.cc +++ b/chromium/net/quic/quic_crypto_server_stream_test.cc @@ -36,25 +36,6 @@ namespace net { namespace test { namespace { -// TODO(agl): Use rch's utility class for parsing a message when committed. -class TestQuicVisitor : public NoOpFramerVisitor { - public: - TestQuicVisitor() {} - - // NoOpFramerVisitor - virtual bool OnStreamFrame(const QuicStreamFrame& frame) OVERRIDE { - frame_ = frame; - return true; - } - - QuicStreamFrame* frame() { return &frame_; } - - private: - QuicStreamFrame frame_; - - DISALLOW_COPY_AND_ASSIGN(TestQuicVisitor); -}; - class QuicCryptoServerStreamTest : public ::testing::Test { public: QuicCryptoServerStreamTest() diff --git a/chromium/net/quic/quic_crypto_stream.cc b/chromium/net/quic/quic_crypto_stream.cc index 2f06e3ba234..3c10c5bfbf9 100644 --- a/chromium/net/quic/quic_crypto_stream.cc +++ b/chromium/net/quic/quic_crypto_stream.cc @@ -27,6 +27,11 @@ void QuicCryptoStream::OnError(CryptoFramer* framer) { session()->ConnectionClose(framer->error(), false); } +void QuicCryptoStream::OnHandshakeMessage( + const CryptoHandshakeMessage& message) { + session()->OnCryptoHandshakeMessageReceived(message); +} + uint32 QuicCryptoStream::ProcessData(const char* data, uint32 data_len) { // Do not process handshake messages after the handshake is confirmed. @@ -52,9 +57,14 @@ void QuicCryptoStream::CloseConnectionWithDetails(QuicErrorCode error, void QuicCryptoStream::SendHandshakeMessage( const CryptoHandshakeMessage& message) { + session()->OnCryptoHandshakeMessageSent(message); const QuicData& data = message.GetSerialized(); + // To make reasoning about crypto frames easier, we don't combine them with + // any other frames in a single packet. + session()->connection()->Flush(); // TODO(wtc): check the return value. WriteData(string(data.data(), data.length()), false); + session()->connection()->Flush(); } const QuicCryptoNegotiatedParameters& diff --git a/chromium/net/quic/quic_crypto_stream.h b/chromium/net/quic/quic_crypto_stream.h index bdae59e31f5..c402b0d9b44 100644 --- a/chromium/net/quic/quic_crypto_stream.h +++ b/chromium/net/quic/quic_crypto_stream.h @@ -34,7 +34,8 @@ class NET_EXPORT_PRIVATE QuicCryptoStream // CryptoFramerVisitorInterface implementation virtual void OnError(CryptoFramer* framer) OVERRIDE; - virtual void OnHandshakeMessage(const CryptoHandshakeMessage& message) = 0; + virtual void OnHandshakeMessage( + const CryptoHandshakeMessage& message) OVERRIDE; // ReliableQuicStream implementation virtual uint32 ProcessData(const char* data, uint32 data_len) OVERRIDE; diff --git a/chromium/net/quic/quic_data_writer.cc b/chromium/net/quic/quic_data_writer.cc index e52cd03248b..61e72922365 100644 --- a/chromium/net/quic/quic_data_writer.cc +++ b/chromium/net/quic/quic_data_writer.cc @@ -56,10 +56,6 @@ bool QuicDataWriter::WriteUInt64(uint64 value) { return WriteBytes(&value, sizeof(value)); } -bool QuicDataWriter::WriteUInt128(uint128 value) { - return WriteUInt64(Uint128Low64(value)) && WriteUInt64(Uint128High64(value)); -} - bool QuicDataWriter::WriteStringPiece16(StringPiece val) { if (val.length() > numeric_limits<uint16>::max()) { return false; diff --git a/chromium/net/quic/quic_data_writer.h b/chromium/net/quic/quic_data_writer.h index f3408d12215..b18121da459 100644 --- a/chromium/net/quic/quic_data_writer.h +++ b/chromium/net/quic/quic_data_writer.h @@ -43,7 +43,6 @@ class NET_EXPORT_PRIVATE QuicDataWriter { bool WriteUInt32(uint32 value); bool WriteUInt48(uint64 value); bool WriteUInt64(uint64 value); - bool WriteUInt128(uint128 value); bool WriteStringPiece16(base::StringPiece val); bool WriteBytes(const void* data, size_t data_len); bool WriteRepeatedByte(uint8 byte, size_t count); @@ -61,9 +60,6 @@ class NET_EXPORT_PRIVATE QuicDataWriter { return capacity_; } - protected: - const char* end_of_payload() const { return buffer_ + length_; } - private: // Returns the location that the data should be written at, or NULL if there // is not enough room. Call EndWrite with the returned offset and the given diff --git a/chromium/net/quic/quic_framer.cc b/chromium/net/quic/quic_framer.cc index 8796456bcdc..cfc8c95774b 100644 --- a/chromium/net/quic/quic_framer.cc +++ b/chromium/net/quic/quic_framer.cc @@ -20,6 +20,19 @@ namespace net { namespace { +// TODO(jri): Remove uses of QuicFrameTypeOld when +// QUIC versions < 10 are no longer supported. +enum QuicFrameTypeOld { + PADDING_FRAME_OLD = 0, + STREAM_FRAME_OLD, + ACK_FRAME_OLD, + CONGESTION_FEEDBACK_FRAME_OLD, + RST_STREAM_FRAME_OLD, + CONNECTION_CLOSE_FRAME_OLD, + GOAWAY_FRAME_OLD, + NUM_FRAME_TYPES_OLD +}; + // Mask to select the lowest 48 bits of a sequence number. const QuicPacketSequenceNumber k6ByteSequenceNumberMask = GG_UINT64_C(0x0000FFFFFFFFFFFF); @@ -33,8 +46,36 @@ const QuicPacketSequenceNumber k1ByteSequenceNumberMask = const QuicGuid k1ByteGuidMask = GG_UINT64_C(0x00000000000000FF); const QuicGuid k4ByteGuidMask = GG_UINT64_C(0x00000000FFFFFFFF); +// New Frame Types, QUIC v. >= 10: +// There are two interpretations for the Frame Type byte in the QUIC protocol, +// resulting in two Frame Types: Special Frame Types and Regular Frame Types. +// +// Regular Frame Types use the Frame Type byte simply. Currently defined +// Regular Frame Types are: +// Padding : 0b 00000000 (0x00) +// ResetStream : 0b 00000001 (0x01) +// ConnectionClose : 0b 00000010 (0x02) +// GoAway : 0b 00000011 (0x03) +// +// Special Frame Types encode both a Frame Type and corresponding flags +// all in the Frame Type byte. Currently defined Special Frame Types are: +// Stream : 0b 1xxxxxxx +// Ack : 0b 01xxxxxx +// CongestionFeedback : 0b 001xxxxx +// +// Semantics of the flag bits above (the x bits) depends on the frame type. + +// Masks to determine if the frame type is a special use +// and for specific special frame types. +const uint8 kQuicFrameTypeSpecialMask = 0xE0; // 0b 11100000 +const uint8 kQuicFrameTypeStreamMask = 0x80; +const uint8 kQuicFrameTypeAckMask = 0x40; +const uint8 kQuicFrameTypeCongestionFeedbackMask = 0x20; + // Mask to determine if it's a special frame type(Stream, Ack, or // Congestion Control) by checking if the first bit is 0, then shifting right. +// TODO(jri): Remove kQuicFrameType0BitMask constant from v. 10 onwards. +// Replaced by kQuicFrameTypeStream defined above. const uint8 kQuicFrameType0BitMask = 0x01; // Default frame type shift and mask. @@ -260,7 +301,8 @@ SerializedPacket QuicFramer::BuildDataPacket( const QuicFrames& frames, size_t packet_size) { QuicDataWriter writer(packet_size); - const SerializedPacket kNoPacket(0, NULL, 0, NULL); + const SerializedPacket kNoPacket( + 0, PACKET_1BYTE_SEQUENCE_NUMBER, NULL, 0, NULL); if (!WritePacketHeader(header, &writer)) { return kNoPacket; } @@ -331,7 +373,8 @@ SerializedPacket QuicFramer::BuildDataPacket( packet->FecProtectedData()); } - return SerializedPacket(header.packet_sequence_number, packet, + return SerializedPacket(header.packet_sequence_number, + header.public_header.sequence_number_length, packet, GetPacketEntropyHash(header), NULL); } @@ -343,7 +386,8 @@ SerializedPacket QuicFramer::BuildFecPacket(const QuicPacketHeader& header, len += fec.redundancy.length(); QuicDataWriter writer(len); - SerializedPacket kNoPacket = SerializedPacket(0, NULL, 0, NULL); + const SerializedPacket kNoPacket( + 0, PACKET_1BYTE_SEQUENCE_NUMBER, NULL, 0, NULL); if (!WritePacketHeader(header, &writer)) { return kNoPacket; } @@ -354,6 +398,7 @@ SerializedPacket QuicFramer::BuildFecPacket(const QuicPacketHeader& header, return SerializedPacket( header.packet_sequence_number, + header.public_header.sequence_number_length, QuicPacket::NewFecPacket(writer.take(), len, true, header.public_header.guid_length, header.public_header.version_flag, @@ -877,8 +922,8 @@ bool QuicFramer::ProcessPacketSequenceNumber( bool QuicFramer::ProcessFrameData() { if (reader_->IsDoneReading()) { - set_detailed_error("Unable to read frame type."); - return RaiseError(QUIC_INVALID_FRAME_DATA); + set_detailed_error("Packet has no frames."); + return RaiseError(QUIC_MISSING_PAYLOAD); } while (!reader_->IsDoneReading()) { uint8 frame_type; @@ -887,62 +932,16 @@ bool QuicFramer::ProcessFrameData() { return RaiseError(QUIC_INVALID_FRAME_DATA); } - if ((frame_type & kQuicFrameType0BitMask) == 0) { - QuicStreamFrame frame; - if (!ProcessStreamFrame(frame_type, &frame)) { - return RaiseError(QUIC_INVALID_FRAME_DATA); - } - if (!visitor_->OnStreamFrame(frame)) { - DLOG(INFO) << "Visitor asked to stop further processing."; - // Returning true since there was no parsing error. - return true; - } - continue; - } - - frame_type >>= 1; - if ((frame_type & kQuicFrameType0BitMask) == 0) { - QuicAckFrame frame; - if (!ProcessAckFrame(&frame)) { - return RaiseError(QUIC_INVALID_FRAME_DATA); - } - if (!visitor_->OnAckFrame(frame)) { - DLOG(INFO) << "Visitor asked to stop further processing."; - // Returning true since there was no parsing error. - return true; - } - continue; - } - - frame_type >>= 1; - if ((frame_type & kQuicFrameType0BitMask) == 0) { - QuicCongestionFeedbackFrame frame; - if (!ProcessQuicCongestionFeedbackFrame(&frame)) { - return RaiseError(QUIC_INVALID_FRAME_DATA); - } - if (!visitor_->OnCongestionFeedbackFrame(frame)) { - DLOG(INFO) << "Visitor asked to stop further processing."; - // Returning true since there was no parsing error. - return true; - } - continue; - } - - frame_type >>= 1; - - switch (frame_type) { - // STREAM_FRAME, ACK_FRAME, and CONGESTION_FEEDBACK_FRAME are handled - // above. - case PADDING_FRAME: - // We're done with the packet - return true; - - case RST_STREAM_FRAME: { - QuicRstStreamFrame frame; - if (!ProcessRstStreamFrame(&frame)) { - return RaiseError(QUIC_INVALID_RST_STREAM_DATA); + // TODO(jri): Remove this entire if block when support for + // QUIC version < 10 removed. + if (version() < QUIC_VERSION_10) { + // Special frame type processing for QUIC version < 10. + if ((frame_type & kQuicFrameType0BitMask) == 0) { + QuicStreamFrame frame; + if (!ProcessStreamFrame(frame_type, &frame)) { + return RaiseError(QUIC_INVALID_STREAM_DATA); } - if (!visitor_->OnRstStreamFrame(frame)) { + if (!visitor_->OnStreamFrame(frame)) { DLOG(INFO) << "Visitor asked to stop further processing."; // Returning true since there was no parsing error. return true; @@ -950,19 +949,27 @@ bool QuicFramer::ProcessFrameData() { continue; } - case CONNECTION_CLOSE_FRAME: { - QuicConnectionCloseFrame frame; - if (!ProcessConnectionCloseFrame(&frame)) { - return RaiseError(QUIC_INVALID_CONNECTION_CLOSE_DATA); + frame_type >>= 1; + if ((frame_type & kQuicFrameType0BitMask) == 0) { + QuicAckFrame frame; + if (!ProcessAckFrame(&frame)) { + return RaiseError(QUIC_INVALID_ACK_DATA); } - - if (!visitor_->OnAckFrame(frame.ack_frame)) { + if (!visitor_->OnAckFrame(frame)) { DLOG(INFO) << "Visitor asked to stop further processing."; // Returning true since there was no parsing error. return true; } + continue; + } - if (!visitor_->OnConnectionCloseFrame(frame)) { + frame_type >>= 1; + if ((frame_type & kQuicFrameType0BitMask) == 0) { + QuicCongestionFeedbackFrame frame; + if (!ProcessQuicCongestionFeedbackFrame(&frame)) { + return RaiseError(QUIC_INVALID_CONGESTION_FEEDBACK_DATA); + } + if (!visitor_->OnCongestionFeedbackFrame(frame)) { DLOG(INFO) << "Visitor asked to stop further processing."; // Returning true since there was no parsing error. return true; @@ -970,23 +977,178 @@ bool QuicFramer::ProcessFrameData() { continue; } - case GOAWAY_FRAME: { - QuicGoAwayFrame goaway_frame; - if (!ProcessGoAwayFrame(&goaway_frame)) { - return RaiseError(QUIC_INVALID_GOAWAY_DATA); - } - if (!visitor_->OnGoAwayFrame(goaway_frame)) { - DLOG(INFO) << "Visitor asked to stop further processing."; - // Returning true since there was no parsing error. + frame_type >>= 1; + switch (frame_type) { + // STREAM_FRAME, ACK_FRAME, and CONGESTION_FEEDBACK_FRAME are handled + // above. + case PADDING_FRAME_OLD: + // We're done with the packet. return true; + + case RST_STREAM_FRAME_OLD: { + QuicRstStreamFrame frame; + if (!ProcessRstStreamFrame(&frame)) { + return RaiseError(QUIC_INVALID_RST_STREAM_DATA); + } + if (!visitor_->OnRstStreamFrame(frame)) { + DLOG(INFO) << "Visitor asked to stop further processing."; + // Returning true since there was no parsing error. + return true; + } + continue; } - continue; + + case CONNECTION_CLOSE_FRAME_OLD: { + QuicConnectionCloseFrame frame; + if (!ProcessConnectionCloseFrame(&frame)) { + return RaiseError(QUIC_INVALID_CONNECTION_CLOSE_DATA); + } + + if (!visitor_->OnAckFrame(frame.ack_frame)) { + DLOG(INFO) << "Visitor asked to stop further processing."; + // Returning true since there was no parsing error. + return true; + } + + if (!visitor_->OnConnectionCloseFrame(frame)) { + DLOG(INFO) << "Visitor asked to stop further processing."; + // Returning true since there was no parsing error. + return true; + } + continue; + } + + case GOAWAY_FRAME_OLD: { + QuicGoAwayFrame goaway_frame; + if (!ProcessGoAwayFrame(&goaway_frame)) { + return RaiseError(QUIC_INVALID_GOAWAY_DATA); + } + if (!visitor_->OnGoAwayFrame(goaway_frame)) { + DLOG(INFO) << "Visitor asked to stop further processing."; + // Returning true since there was no parsing error. + return true; + } + continue; + } + + set_detailed_error("Illegal frame type."); + DLOG(WARNING) << "Illegal frame type: " + << static_cast<int>(frame_type); + return RaiseError(QUIC_INVALID_FRAME_DATA); } + } else { + // TODO(jri): Retain this else block when support for + // QUIC version < 10 removed. Remove above if block. + + // Special frame type processing for QUIC version >= 10. + if (frame_type & kQuicFrameTypeSpecialMask) { + // Stream Frame + if (frame_type & kQuicFrameTypeStreamMask) { + QuicStreamFrame frame; + if (!ProcessStreamFrame(frame_type, &frame)) { + return RaiseError(QUIC_INVALID_STREAM_DATA); + } + if (!visitor_->OnStreamFrame(frame)) { + DLOG(INFO) << "Visitor asked to stop further processing."; + // Returning true since there was no parsing error. + return true; + } + continue; + } - set_detailed_error("Illegal frame type."); - DLOG(WARNING) << "Illegal frame type: " - << static_cast<int>(frame_type); - return RaiseError(QUIC_INVALID_FRAME_DATA); + // Ack Frame + if (frame_type & kQuicFrameTypeAckMask) { + QuicAckFrame frame; + if (!ProcessAckFrame(&frame)) { + return RaiseError(QUIC_INVALID_ACK_DATA); + } + if (!visitor_->OnAckFrame(frame)) { + DLOG(INFO) << "Visitor asked to stop further processing."; + // Returning true since there was no parsing error. + return true; + } + continue; + } + + // Congestion Feedback Frame + if (frame_type & kQuicFrameTypeCongestionFeedbackMask) { + QuicCongestionFeedbackFrame frame; + if (!ProcessQuicCongestionFeedbackFrame(&frame)) { + return RaiseError(QUIC_INVALID_CONGESTION_FEEDBACK_DATA); + } + if (!visitor_->OnCongestionFeedbackFrame(frame)) { + DLOG(INFO) << "Visitor asked to stop further processing."; + // Returning true since there was no parsing error. + return true; + } + continue; + } + + // This was a special frame type that did not match any + // of the known ones. Error. + set_detailed_error("Illegal frame type."); + DLOG(WARNING) << "Illegal frame type: " + << static_cast<int>(frame_type); + return RaiseError(QUIC_INVALID_FRAME_DATA); + } + + switch (frame_type) { + case PADDING_FRAME: + // We're done with the packet. + return true; + + case RST_STREAM_FRAME: { + QuicRstStreamFrame frame; + if (!ProcessRstStreamFrame(&frame)) { + return RaiseError(QUIC_INVALID_RST_STREAM_DATA); + } + if (!visitor_->OnRstStreamFrame(frame)) { + DLOG(INFO) << "Visitor asked to stop further processing."; + // Returning true since there was no parsing error. + return true; + } + continue; + } + + case CONNECTION_CLOSE_FRAME: { + QuicConnectionCloseFrame frame; + if (!ProcessConnectionCloseFrame(&frame)) { + return RaiseError(QUIC_INVALID_CONNECTION_CLOSE_DATA); + } + + if (!visitor_->OnAckFrame(frame.ack_frame)) { + DLOG(INFO) << "Visitor asked to stop further processing."; + // Returning true since there was no parsing error. + return true; + } + + if (!visitor_->OnConnectionCloseFrame(frame)) { + DLOG(INFO) << "Visitor asked to stop further processing."; + // Returning true since there was no parsing error. + return true; + } + continue; + } + + case GOAWAY_FRAME: { + QuicGoAwayFrame goaway_frame; + if (!ProcessGoAwayFrame(&goaway_frame)) { + return RaiseError(QUIC_INVALID_GOAWAY_DATA); + } + if (!visitor_->OnGoAwayFrame(goaway_frame)) { + DLOG(INFO) << "Visitor asked to stop further processing."; + // Returning true since there was no parsing error. + return true; + } + continue; + } + + default: + set_detailed_error("Illegal frame type."); + DLOG(WARNING) << "Illegal frame type: " + << static_cast<int>(frame_type); + return RaiseError(QUIC_INVALID_FRAME_DATA); + } } } @@ -995,7 +1157,15 @@ bool QuicFramer::ProcessFrameData() { bool QuicFramer::ProcessStreamFrame(uint8 frame_type, QuicStreamFrame* frame) { - uint8 stream_flags = frame_type >> 1; + uint8 stream_flags = frame_type; + + // TODO(jri): Remove if block after support for ver. < 10 removed. + if (version() < QUIC_VERSION_10) { + stream_flags >>= 1; + } else { + stream_flags &= ~kQuicFrameTypeStreamMask; + } + // Read from right to left: StreamID, Offset, Data Length, Fin. const uint8 stream_id_length = (stream_flags & kQuicStreamIDLengthMask) + 1; stream_flags >>= kQuicStreamIdShift; @@ -1543,22 +1713,44 @@ bool QuicFramer::AppendTypeByte(const QuicFrame& frame, type_byte <<= kQuicStreamIdShift; type_byte |= GetStreamIdSize(frame.stream_frame->stream_id) - 1; - type_byte <<= 1; // Leaves the last bit as a 0. + // TODO(jri): Remove if block when support for QUIC ver. < 10 removed. + if (version() < QUIC_VERSION_10) { + type_byte <<= 1; // Leaves the last bit as a 0. + } else { + type_byte |= kQuicFrameTypeStreamMask; // Set Stream Frame Type to 1. + } break; } case ACK_FRAME: { // TODO(ianswett): Use extra 5 bits in the ack framing. - type_byte = 0x01; + // TODO(jri): Remove if block when support for QUIC ver. < 10 removed. + if (version() < QUIC_VERSION_10) { + type_byte = 0x01; + } else { + type_byte = kQuicFrameTypeAckMask; + } break; } case CONGESTION_FEEDBACK_FRAME: { // TODO(ianswett): Use extra 5 bits in the congestion feedback framing. - type_byte = 0x03; + // TODO(jri): Remove if block when support for QUIC ver. < 10 removed. + if (version() < QUIC_VERSION_10) { + type_byte = 0x03; + } else { + type_byte = kQuicFrameTypeCongestionFeedbackMask; + } break; } default: - type_byte = - frame.type << kQuicDefaultFrameTypeShift | kQuicDefaultFrameTypeMask; + type_byte = frame.type; + // TODO(jri): Remove if block when support for QUIC ver. < 10 removed. + if (version() < QUIC_VERSION_10) { + if (type_byte > 0) { + type_byte += 3; + } + type_byte = (type_byte << kQuicDefaultFrameTypeShift) | + kQuicDefaultFrameTypeMask; + } break; } diff --git a/chromium/net/quic/quic_framer_test.cc b/chromium/net/quic/quic_framer_test.cc index e4148fea31d..b3bdbac35f4 100644 --- a/chromium/net/quic/quic_framer_test.cc +++ b/chromium/net/quic/quic_framer_test.cc @@ -36,8 +36,6 @@ namespace test { const QuicPacketSequenceNumber kEpoch = GG_UINT64_C(1) << 48; const QuicPacketSequenceNumber kMask = kEpoch - 1; -// Index into the flags offset in the header. -const size_t kPublicFlagsOffset = 0; // Index into the guid offset in the header. const size_t kGuidOffset = kPublicFlagsSize; // Index into the version string in the header. (if present). @@ -321,6 +319,18 @@ class QuicFramerTest : public ::testing::TestWithParam<QuicVersion> { framer_.set_version(version_); } + // Helper function to get unsigned char representation of digit in the + // units place of the current QUIC version number. + unsigned char GetQuicVersionDigitOnes() { + return static_cast<unsigned char> ('0' + version_%10); + } + + // Helper function to get unsigned char representation of digit in the + // tens place of the current QUIC version number. + unsigned char GetQuicVersionDigitTens() { + return static_cast<unsigned char> ('0' + (version_/10)%10); + } + bool CheckEncryption(QuicPacketSequenceNumber sequence_number, QuicPacket* packet) { if (sequence_number != encrypter_->sequence_number_) { @@ -390,11 +400,28 @@ class QuicFramerTest : public ::testing::TestWithParam<QuicVersion> { EXPECT_EQ(error_code, framer_.error()) << "len: " << len; } - void ValidateTruncatedAck(const QuicAckFrame* ack, size_t keys) { - for (size_t i = 1; i < keys; ++i) { - EXPECT_TRUE(ContainsKey(ack->received_info.missing_packets, i)) << i; + void CheckStreamFrameBoundaries(unsigned char* packet, + size_t stream_id_size, + bool include_version) { + // Now test framing boundaries + for (size_t i = kQuicFrameTypeSize; + i < GetMinStreamFrameSize(framer_.version()); ++i) { + string expected_error; + if (i < kQuicFrameTypeSize + stream_id_size) { + expected_error = "Unable to read stream_id."; + } else if (i < kQuicFrameTypeSize + stream_id_size + + kQuicMaxStreamOffsetSize) { + expected_error = "Unable to read offset."; + } else { + expected_error = "Unable to read frame data."; + } + CheckProcessingFails( + packet, + i + GetPacketHeaderSize(PACKET_8BYTE_GUID, include_version, + PACKET_6BYTE_SEQUENCE_NUMBER, + NOT_IN_FEC_GROUP), + expected_error, QUIC_INVALID_STREAM_DATA); } - EXPECT_EQ(keys, ack->received_info.largest_observed); } void CheckCalculatePacketSequenceNumber( @@ -580,7 +607,7 @@ TEST_P(QuicFramerTest, PacketHeader) { QuicEncryptedPacket encrypted(AsChars(packet), arraysize(packet), false); EXPECT_FALSE(framer_.ProcessPacket(encrypted)); - EXPECT_EQ(QUIC_INVALID_FRAME_DATA, framer_.error()); + EXPECT_EQ(QUIC_MISSING_PAYLOAD, framer_.error()); ASSERT_TRUE(visitor_.header_.get()); EXPECT_EQ(GG_UINT64_C(0xFEDCBA9876543210), visitor_.header_->public_header.guid); @@ -633,7 +660,7 @@ TEST_P(QuicFramerTest, PacketHeaderWith4ByteGuid) { QuicEncryptedPacket encrypted(AsChars(packet), arraysize(packet), false); EXPECT_FALSE(framer_.ProcessPacket(encrypted)); - EXPECT_EQ(QUIC_INVALID_FRAME_DATA, framer_.error()); + EXPECT_EQ(QUIC_MISSING_PAYLOAD, framer_.error()); ASSERT_TRUE(visitor_.header_.get()); EXPECT_EQ(GG_UINT64_C(0xFEDCBA9876543210), visitor_.header_->public_header.guid); @@ -688,7 +715,7 @@ TEST_P(QuicFramerTest, PacketHeader1ByteGuid) { QuicEncryptedPacket encrypted(AsChars(packet), arraysize(packet), false); EXPECT_FALSE(framer_.ProcessPacket(encrypted)); - EXPECT_EQ(QUIC_INVALID_FRAME_DATA, framer_.error()); + EXPECT_EQ(QUIC_MISSING_PAYLOAD, framer_.error()); ASSERT_TRUE(visitor_.header_.get()); EXPECT_EQ(GG_UINT64_C(0xFEDCBA9876543210), visitor_.header_->public_header.guid); @@ -741,7 +768,7 @@ TEST_P(QuicFramerTest, PacketHeaderWith0ByteGuid) { QuicEncryptedPacket encrypted(AsChars(packet), arraysize(packet), false); EXPECT_FALSE(framer_.ProcessPacket(encrypted)); - EXPECT_EQ(QUIC_INVALID_FRAME_DATA, framer_.error()); + EXPECT_EQ(QUIC_MISSING_PAYLOAD, framer_.error()); ASSERT_TRUE(visitor_.header_.get()); EXPECT_EQ(GG_UINT64_C(0xFEDCBA9876543210), visitor_.header_->public_header.guid); @@ -778,9 +805,6 @@ TEST_P(QuicFramerTest, PacketHeaderWith0ByteGuid) { } TEST_P(QuicFramerTest, PacketHeaderWithVersionFlag) { - // Set a specific version. - framer_.set_version(QUIC_VERSION_7); - unsigned char packet[] = { // public flags (version) 0x3D, @@ -788,7 +812,7 @@ TEST_P(QuicFramerTest, PacketHeaderWithVersionFlag) { 0x10, 0x32, 0x54, 0x76, 0x98, 0xBA, 0xDC, 0xFE, // version tag - 'Q', '0', '0', '7', + 'Q', '0', GetQuicVersionDigitTens(), GetQuicVersionDigitOnes(), // packet sequence number 0xBC, 0x9A, 0x78, 0x56, 0x34, 0x12, @@ -798,13 +822,13 @@ TEST_P(QuicFramerTest, PacketHeaderWithVersionFlag) { QuicEncryptedPacket encrypted(AsChars(packet), arraysize(packet), false); EXPECT_FALSE(framer_.ProcessPacket(encrypted)); - EXPECT_EQ(QUIC_INVALID_FRAME_DATA, framer_.error()); + EXPECT_EQ(QUIC_MISSING_PAYLOAD, framer_.error()); ASSERT_TRUE(visitor_.header_.get()); EXPECT_EQ(GG_UINT64_C(0xFEDCBA9876543210), visitor_.header_->public_header.guid); EXPECT_FALSE(visitor_.header_->public_header.reset_flag); EXPECT_TRUE(visitor_.header_->public_header.version_flag); - EXPECT_EQ(QUIC_VERSION_7, visitor_.header_->public_header.versions[0]); + EXPECT_EQ(GetParam(), visitor_.header_->public_header.versions[0]); EXPECT_FALSE(visitor_.header_->fec_flag); EXPECT_FALSE(visitor_.header_->entropy_flag); EXPECT_EQ(0, visitor_.header_->entropy_hash); @@ -854,7 +878,7 @@ TEST_P(QuicFramerTest, PacketHeaderWith4ByteSequenceNumber) { QuicEncryptedPacket encrypted(AsChars(packet), arraysize(packet), false); EXPECT_FALSE(framer_.ProcessPacket(encrypted)); - EXPECT_EQ(QUIC_INVALID_FRAME_DATA, framer_.error()); + EXPECT_EQ(QUIC_MISSING_PAYLOAD, framer_.error()); ASSERT_TRUE(visitor_.header_.get()); EXPECT_EQ(GG_UINT64_C(0xFEDCBA9876543210), visitor_.header_->public_header.guid); @@ -909,7 +933,7 @@ TEST_P(QuicFramerTest, PacketHeaderWith2ByteSequenceNumber) { QuicEncryptedPacket encrypted(AsChars(packet), arraysize(packet), false); EXPECT_FALSE(framer_.ProcessPacket(encrypted)); - EXPECT_EQ(QUIC_INVALID_FRAME_DATA, framer_.error()); + EXPECT_EQ(QUIC_MISSING_PAYLOAD, framer_.error()); ASSERT_TRUE(visitor_.header_.get()); EXPECT_EQ(GG_UINT64_C(0xFEDCBA9876543210), visitor_.header_->public_header.guid); @@ -964,7 +988,7 @@ TEST_P(QuicFramerTest, PacketHeaderWith1ByteSequenceNumber) { QuicEncryptedPacket encrypted(AsChars(packet), arraysize(packet), false); EXPECT_FALSE(framer_.ProcessPacket(encrypted)); - EXPECT_EQ(QUIC_INVALID_FRAME_DATA, framer_.error()); + EXPECT_EQ(QUIC_MISSING_PAYLOAD, framer_.error()); ASSERT_TRUE(visitor_.header_.get()); EXPECT_EQ(GG_UINT64_C(0xFEDCBA9876543210), visitor_.header_->public_header.guid); @@ -1014,23 +1038,10 @@ TEST_P(QuicFramerTest, InvalidPublicFlag) { // private flags 0x00, - // frame count - 0x01, - // frame type (stream frame) - 0x01, - // stream id - 0x04, 0x03, 0x02, 0x01, - // fin - 0x01, - // offset - 0x54, 0x76, 0x10, 0x32, - 0xDC, 0xFE, 0x98, 0xBA, - // data length - 0x0c, 0x00, - // data - 'h', 'e', 'l', 'l', - 'o', ' ', 'w', 'o', - 'r', 'l', 'd', '!', + // frame type (padding) + static_cast<unsigned char>( + GetParam() < QUIC_VERSION_10 ? 0x07 : 0x00), + 0x00, 0x00, 0x00, 0x00 }; CheckProcessingFails(packet, arraysize(packet), @@ -1039,9 +1050,6 @@ TEST_P(QuicFramerTest, InvalidPublicFlag) { }; TEST_P(QuicFramerTest, InvalidPublicFlagWithMatchingVersions) { - // Set a specific version. - framer_.set_version(QUIC_VERSION_7); - unsigned char packet[] = { // public flags (8 byte guid and version flag and an unknown flag) 0x4D, @@ -1049,30 +1057,17 @@ TEST_P(QuicFramerTest, InvalidPublicFlagWithMatchingVersions) { 0x10, 0x32, 0x54, 0x76, 0x98, 0xBA, 0xDC, 0xFE, // version tag - 'Q', '0', '0', '7', + 'Q', '0', GetQuicVersionDigitTens(), GetQuicVersionDigitOnes(), // packet sequence number 0xBC, 0x9A, 0x78, 0x56, 0x34, 0x12, // private flags 0x00, - // frame count - 0x01, - // frame type (stream frame) - 0x01, - // stream id - 0x04, 0x03, 0x02, 0x01, - // fin - 0x01, - // offset - 0x54, 0x76, 0x10, 0x32, - 0xDC, 0xFE, 0x98, 0xBA, - // data length - 0x0c, 0x00, - // data - 'h', 'e', 'l', 'l', - 'o', ' ', 'w', 'o', - 'r', 'l', 'd', '!', + // frame type (padding) + static_cast<unsigned char>( + GetParam() < QUIC_VERSION_10 ? 0x07 : 0x00), + 0x00, 0x00, 0x00, 0x00 }; CheckProcessingFails(packet, arraysize(packet), @@ -1096,7 +1091,9 @@ TEST_P(QuicFramerTest, LargePublicFlagWithMismatchedVersions) { 0x00, // frame type (padding frame) - 0x07, + static_cast<unsigned char>( + GetParam() < QUIC_VERSION_10 ? 0x07 : 0x00), + 0x00, 0x00, 0x00, 0x00 }; QuicEncryptedPacket encrypted(AsChars(packet), arraysize(packet), false); EXPECT_TRUE(framer_.ProcessPacket(encrypted)); @@ -1119,23 +1116,10 @@ TEST_P(QuicFramerTest, InvalidPrivateFlag) { // private flags 0x10, - // frame count - 0x01, - // frame type (stream frame) - 0x01, - // stream id - 0x04, 0x03, 0x02, 0x01, - // fin - 0x01, - // offset - 0x54, 0x76, 0x10, 0x32, - 0xDC, 0xFE, 0x98, 0xBA, - // data length - 0x0c, 0x00, - // data - 'h', 'e', 'l', 'l', - 'o', ' ', 'w', 'o', - 'r', 'l', 'd', '!', + // frame type (padding) + static_cast<unsigned char>( + GetParam() < QUIC_VERSION_10 ? 0x07 : 0x00), + 0x00, 0x00, 0x00, 0x00 }; CheckProcessingFails(packet, arraysize(packet), @@ -1179,14 +1163,20 @@ TEST_P(QuicFramerTest, PaddingFrame) { 0x00, // frame type (padding frame) - 0x07, + static_cast<unsigned char>( + GetParam() < QUIC_VERSION_10 ? 0x07 : 0x00), // Ignored data (which in this case is a stream frame) - 0x01, + // frame type (stream frame with fin) + static_cast<unsigned char>( + GetParam() < QUIC_VERSION_10 ? 0xFE : 0xFF), + // stream id 0x04, 0x03, 0x02, 0x01, - 0x01, + // offset 0x54, 0x76, 0x10, 0x32, 0xDC, 0xFE, 0x98, 0xBA, + // data length 0x0c, 0x00, + // data 'h', 'e', 'l', 'l', 'o', ' ', 'w', 'o', 'r', 'l', 'd', '!', @@ -1205,7 +1195,7 @@ TEST_P(QuicFramerTest, PaddingFrame) { packet, GetPacketHeaderSize(PACKET_8BYTE_GUID, !kIncludeVersion, PACKET_6BYTE_SEQUENCE_NUMBER, NOT_IN_FEC_GROUP), - "Unable to read frame type.", QUIC_INVALID_FRAME_DATA); + "Packet has no frames.", QUIC_MISSING_PAYLOAD); } TEST_P(QuicFramerTest, StreamFrame) { @@ -1222,7 +1212,8 @@ TEST_P(QuicFramerTest, StreamFrame) { 0x00, // frame type (stream frame with fin) - 0xFE, + static_cast<unsigned char>( + GetParam() < QUIC_VERSION_10 ? 0xFE : 0xFF), // stream id 0x04, 0x03, 0x02, 0x01, // offset @@ -1253,26 +1244,7 @@ TEST_P(QuicFramerTest, StreamFrame) { EXPECT_EQ("hello world!", visitor_.stream_frames_[0]->data); // Now test framing boundaries - for (size_t i = 0; i < GetMinStreamFrameSize(framer_.version()); ++i) { - string expected_error; - if (i < kQuicFrameTypeSize) { - expected_error = "Unable to read frame type."; - } else if (i < kQuicFrameTypeSize + kQuicMaxStreamIdSize) { - expected_error = "Unable to read stream_id."; - } else if (i < kQuicFrameTypeSize + kQuicMaxStreamIdSize) { - expected_error = "Unable to read fin."; - } else if (i < kQuicFrameTypeSize + kQuicMaxStreamIdSize + - kQuicMaxStreamOffsetSize) { - expected_error = "Unable to read offset."; - } else { - expected_error = "Unable to read frame data."; - } - CheckProcessingFails( - packet, - i + GetPacketHeaderSize(PACKET_8BYTE_GUID, !kIncludeVersion, - PACKET_6BYTE_SEQUENCE_NUMBER, NOT_IN_FEC_GROUP), - expected_error, QUIC_INVALID_FRAME_DATA); - } + CheckStreamFrameBoundaries(packet, kQuicMaxStreamIdSize, !kIncludeVersion); } TEST_P(QuicFramerTest, StreamFrame3ByteStreamId) { @@ -1289,7 +1261,8 @@ TEST_P(QuicFramerTest, StreamFrame3ByteStreamId) { 0x00, // frame type (stream frame with fin) - 0xFC, + static_cast<unsigned char>( + GetParam() < QUIC_VERSION_10 ? 0xFC : 0xFE), // stream id 0x04, 0x03, 0x02, // offset @@ -1312,7 +1285,7 @@ TEST_P(QuicFramerTest, StreamFrame3ByteStreamId) { ASSERT_EQ(1u, visitor_.stream_frames_.size()); EXPECT_EQ(0u, visitor_.ack_frames_.size()); - EXPECT_EQ(static_cast<uint64>(0x00020304), + EXPECT_EQ(GG_UINT64_C(0x00020304), visitor_.stream_frames_[0]->stream_id); EXPECT_TRUE(visitor_.stream_frames_[0]->fin); EXPECT_EQ(GG_UINT64_C(0xBA98FEDC32107654), @@ -1321,26 +1294,7 @@ TEST_P(QuicFramerTest, StreamFrame3ByteStreamId) { // Now test framing boundaries const size_t stream_id_size = 3; - for (size_t i = 0; i < GetMinStreamFrameSize(framer_.version()); ++i) { - string expected_error; - if (i < kQuicFrameTypeSize) { - expected_error = "Unable to read frame type."; - } else if (i < kQuicFrameTypeSize + stream_id_size) { - expected_error = "Unable to read stream_id."; - } else if (i < kQuicFrameTypeSize + stream_id_size - 1) { - expected_error = "Unable to read fin."; - } else if (i < kQuicFrameTypeSize + stream_id_size + - kQuicMaxStreamOffsetSize) { - expected_error = "Unable to read offset."; - } else { - expected_error = "Unable to read frame data."; - } - CheckProcessingFails( - packet, - i + GetPacketHeaderSize(PACKET_8BYTE_GUID, !kIncludeVersion, - PACKET_6BYTE_SEQUENCE_NUMBER, NOT_IN_FEC_GROUP), - expected_error, QUIC_INVALID_FRAME_DATA); - } + CheckStreamFrameBoundaries(packet, stream_id_size, !kIncludeVersion); } TEST_P(QuicFramerTest, StreamFrame2ByteStreamId) { @@ -1357,7 +1311,8 @@ TEST_P(QuicFramerTest, StreamFrame2ByteStreamId) { 0x00, // frame type (stream frame with fin) - 0xFA, + static_cast<unsigned char>( + GetParam() < QUIC_VERSION_10 ? 0xFA : 0xFD), // stream id 0x04, 0x03, // offset @@ -1389,26 +1344,7 @@ TEST_P(QuicFramerTest, StreamFrame2ByteStreamId) { // Now test framing boundaries const size_t stream_id_size = 2; - for (size_t i = 0; i < GetMinStreamFrameSize(framer_.version()); ++i) { - string expected_error; - if (i < kQuicFrameTypeSize) { - expected_error = "Unable to read frame type."; - } else if (i < kQuicFrameTypeSize + stream_id_size) { - expected_error = "Unable to read stream_id."; - } else if (i < kQuicFrameTypeSize + stream_id_size - 1) { - expected_error = "Unable to read fin."; - } else if (i < kQuicFrameTypeSize + stream_id_size + - kQuicMaxStreamOffsetSize) { - expected_error = "Unable to read offset."; - } else { - expected_error = "Unable to read frame data."; - } - CheckProcessingFails( - packet, - i + GetPacketHeaderSize(PACKET_8BYTE_GUID, !kIncludeVersion, - PACKET_6BYTE_SEQUENCE_NUMBER, NOT_IN_FEC_GROUP), - expected_error, QUIC_INVALID_FRAME_DATA); - } + CheckStreamFrameBoundaries(packet, stream_id_size, !kIncludeVersion); } TEST_P(QuicFramerTest, StreamFrame1ByteStreamId) { @@ -1425,7 +1361,8 @@ TEST_P(QuicFramerTest, StreamFrame1ByteStreamId) { 0x00, // frame type (stream frame with fin) - 0xF8, + static_cast<unsigned char>( + GetParam() < QUIC_VERSION_10 ? 0xF8 : 0xFC), // stream id 0x04, // offset @@ -1457,32 +1394,10 @@ TEST_P(QuicFramerTest, StreamFrame1ByteStreamId) { // Now test framing boundaries const size_t stream_id_size = 1; - for (size_t i = 0; i < GetMinStreamFrameSize(framer_.version()); ++i) { - string expected_error; - if (i < kQuicFrameTypeSize) { - expected_error = "Unable to read frame type."; - } else if (i < kQuicFrameTypeSize + stream_id_size) { - expected_error = "Unable to read stream_id."; - } else if (i < kQuicFrameTypeSize + stream_id_size - 1) { - expected_error = "Unable to read fin."; - } else if (i < kQuicFrameTypeSize + stream_id_size + - kQuicMaxStreamOffsetSize) { - expected_error = "Unable to read offset."; - } else { - expected_error = "Unable to read frame data."; - } - CheckProcessingFails( - packet, - i + GetPacketHeaderSize(PACKET_8BYTE_GUID, !kIncludeVersion, - PACKET_6BYTE_SEQUENCE_NUMBER, NOT_IN_FEC_GROUP), - expected_error, QUIC_INVALID_FRAME_DATA); - } + CheckStreamFrameBoundaries(packet, stream_id_size, !kIncludeVersion); } TEST_P(QuicFramerTest, StreamFrameWithVersion) { - // Set a specific version. - framer_.set_version(QUIC_VERSION_7); - unsigned char packet[] = { // public flags (version, 8 byte guid) 0x3D, @@ -1490,7 +1405,7 @@ TEST_P(QuicFramerTest, StreamFrameWithVersion) { 0x10, 0x32, 0x54, 0x76, 0x98, 0xBA, 0xDC, 0xFE, // version tag - 'Q', '0', '0', '7', + 'Q', '0', GetQuicVersionDigitTens(), GetQuicVersionDigitOnes(), // packet sequence number 0xBC, 0x9A, 0x78, 0x56, 0x34, 0x12, @@ -1498,7 +1413,8 @@ TEST_P(QuicFramerTest, StreamFrameWithVersion) { 0x00, // frame type (stream frame with fin) - 0xFE, + static_cast<unsigned char>( + GetParam() < QUIC_VERSION_10 ? 0xFE : 0xFF), // stream id 0x04, 0x03, 0x02, 0x01, // offset @@ -1518,7 +1434,7 @@ TEST_P(QuicFramerTest, StreamFrameWithVersion) { EXPECT_EQ(QUIC_NO_ERROR, framer_.error()); ASSERT_TRUE(visitor_.header_.get()); EXPECT_TRUE(visitor_.header_.get()->public_header.version_flag); - EXPECT_EQ(QUIC_VERSION_7, visitor_.header_.get()->public_header.versions[0]); + EXPECT_EQ(GetParam(), visitor_.header_.get()->public_header.versions[0]); EXPECT_TRUE(CheckDecryption(encrypted, kIncludeVersion)); ASSERT_EQ(1u, visitor_.stream_frames_.size()); @@ -1531,24 +1447,7 @@ TEST_P(QuicFramerTest, StreamFrameWithVersion) { EXPECT_EQ("hello world!", visitor_.stream_frames_[0]->data); // Now test framing boundaries - for (size_t i = 0; i < GetMinStreamFrameSize(framer_.version()); ++i) { - string expected_error; - if (i < kQuicFrameTypeSize) { - expected_error = "Unable to read frame type."; - } else if (i < kQuicFrameTypeSize + kQuicMaxStreamIdSize) { - expected_error = "Unable to read stream_id."; - } else if (i < kQuicFrameTypeSize + kQuicMaxStreamIdSize + - kQuicMaxStreamOffsetSize) { - expected_error = "Unable to read offset."; - } else { - expected_error = "Unable to read frame data."; - } - CheckProcessingFails( - packet, - i + GetPacketHeaderSize(PACKET_8BYTE_GUID, kIncludeVersion, - PACKET_6BYTE_SEQUENCE_NUMBER, NOT_IN_FEC_GROUP), - expected_error, QUIC_INVALID_FRAME_DATA); - } + CheckStreamFrameBoundaries(packet, kQuicMaxStreamIdSize, kIncludeVersion); } TEST_P(QuicFramerTest, RejectPacket) { @@ -1567,7 +1466,8 @@ TEST_P(QuicFramerTest, RejectPacket) { 0x00, // frame type (stream frame with fin) - 0xFE, + static_cast<unsigned char>( + GetParam() < QUIC_VERSION_10 ? 0xFE : 0xFF), // stream id 0x04, 0x03, 0x02, 0x01, // offset @@ -1595,7 +1495,8 @@ TEST_P(QuicFramerTest, RejectPacket) { TEST_P(QuicFramerTest, RevivedStreamFrame) { unsigned char payload[] = { // frame type (stream frame with fin) - 0xFE, + static_cast<unsigned char>( + GetParam() < QUIC_VERSION_10 ? 0xFE : 0xFF), // stream id 0x04, 0x03, 0x02, 0x01, // offset @@ -1649,9 +1550,6 @@ TEST_P(QuicFramerTest, RevivedStreamFrame) { } TEST_P(QuicFramerTest, StreamFrameInFecGroup) { - // Set a specific version. - framer_.set_version(QUIC_VERSION_7); - unsigned char packet[] = { // public flags (8 byte guid) 0x3C, @@ -1667,7 +1565,8 @@ TEST_P(QuicFramerTest, StreamFrameInFecGroup) { 0x02, // frame type (stream frame with fin) - 0xFE, + static_cast<unsigned char>( + GetParam() < QUIC_VERSION_10 ? 0xFE : 0xFF), // stream id 0x04, 0x03, 0x02, 0x01, // offset @@ -1719,7 +1618,8 @@ TEST_P(QuicFramerTest, AckFrame) { 0x00, // frame type (ack frame) - static_cast<unsigned char>(0x01), + static_cast<unsigned char>( + GetParam() < QUIC_VERSION_10 ? 0x01 : 0x40), // entropy hash of sent packets till least awaiting - 1. 0xAB, // least packet sequence number awaiting an ack @@ -1772,12 +1672,10 @@ TEST_P(QuicFramerTest, AckFrame) { kNumberOfMissingPacketsSize; // Now test framing boundaries const size_t missing_packets_size = 1 * PACKET_6BYTE_SEQUENCE_NUMBER; - for (size_t i = 0; + for (size_t i = kQuicFrameTypeSize; i < QuicFramer::GetMinAckFrameSize() + missing_packets_size; ++i) { string expected_error; - if (i < kSentEntropyOffset) { - expected_error = "Unable to read frame type."; - } else if (i < kLeastUnackedOffset) { + if (i < kLeastUnackedOffset) { expected_error = "Unable to read entropy hash for sent packets."; } else if (i < kReceivedEntropyOffset) { expected_error = "Unable to read least unacked."; @@ -1796,7 +1694,7 @@ TEST_P(QuicFramerTest, AckFrame) { packet, i + GetPacketHeaderSize(PACKET_8BYTE_GUID, !kIncludeVersion, PACKET_6BYTE_SEQUENCE_NUMBER, NOT_IN_FEC_GROUP), - expected_error, QUIC_INVALID_FRAME_DATA); + expected_error, QUIC_INVALID_ACK_DATA); } } @@ -1814,7 +1712,8 @@ TEST_P(QuicFramerTest, CongestionFeedbackFrameTCP) { 0x00, // frame type (congestion feedback frame) - 0x03, + static_cast<unsigned char>( + GetParam() < QUIC_VERSION_10 ? 0x03 : 0x20), // congestion feedback type (tcp) 0x00, // ack_frame.feedback.tcp.accumulated_number_of_lost_packets @@ -1840,11 +1739,9 @@ TEST_P(QuicFramerTest, CongestionFeedbackFrameTCP) { EXPECT_EQ(0x4030u, frame.tcp.receive_window); // Now test framing boundaries - for (size_t i = 0; i < 6; ++i) { + for (size_t i = kQuicFrameTypeSize; i < 6; ++i) { string expected_error; - if (i < 1) { - expected_error = "Unable to read frame type."; - } else if (i < 2) { + if (i < 2) { expected_error = "Unable to read congestion feedback type."; } else if (i < 4) { expected_error = "Unable to read accumulated number of lost packets."; @@ -1855,7 +1752,7 @@ TEST_P(QuicFramerTest, CongestionFeedbackFrameTCP) { packet, i + GetPacketHeaderSize(PACKET_8BYTE_GUID, !kIncludeVersion, PACKET_6BYTE_SEQUENCE_NUMBER, NOT_IN_FEC_GROUP), - expected_error, QUIC_INVALID_FRAME_DATA); + expected_error, QUIC_INVALID_CONGESTION_FEEDBACK_DATA); } } @@ -1873,7 +1770,8 @@ TEST_P(QuicFramerTest, CongestionFeedbackFrameInterArrival) { 0x00, // frame type (congestion feedback frame) - 0x03, + static_cast<unsigned char>( + GetParam() < QUIC_VERSION_10 ? 0x03 : 0x20), // congestion feedback type (inter arrival) 0x01, // accumulated_number_of_lost_packets @@ -1926,11 +1824,9 @@ TEST_P(QuicFramerTest, CongestionFeedbackFrameInterArrival) { iter->second.Subtract(start_).ToMicroseconds()); // Now test framing boundaries - for (size_t i = 0; i < 31; ++i) { + for (size_t i = kQuicFrameTypeSize; i < 31; ++i) { string expected_error; - if (i < 1) { - expected_error = "Unable to read frame type."; - } else if (i < 2) { + if (i < 2) { expected_error = "Unable to read congestion feedback type."; } else if (i < 4) { expected_error = "Unable to read accumulated number of lost packets."; @@ -1953,7 +1849,7 @@ TEST_P(QuicFramerTest, CongestionFeedbackFrameInterArrival) { packet, i + GetPacketHeaderSize(PACKET_8BYTE_GUID, !kIncludeVersion, PACKET_6BYTE_SEQUENCE_NUMBER, NOT_IN_FEC_GROUP), - expected_error, QUIC_INVALID_FRAME_DATA); + expected_error, QUIC_INVALID_CONGESTION_FEEDBACK_DATA); } } @@ -1971,7 +1867,8 @@ TEST_P(QuicFramerTest, CongestionFeedbackFrameFixRate) { 0x00, // frame type (congestion feedback frame) - 0x03, + static_cast<unsigned char>( + GetParam() < QUIC_VERSION_10 ? 0x03 : 0x20), // congestion feedback type (fix rate) 0x02, // bitrate_in_bytes_per_second; @@ -1994,11 +1891,9 @@ TEST_P(QuicFramerTest, CongestionFeedbackFrameFixRate) { frame.fix_rate.bitrate.ToBytesPerSecond()); // Now test framing boundaries - for (size_t i = 0; i < 6; ++i) { + for (size_t i = kQuicFrameTypeSize; i < 6; ++i) { string expected_error; - if (i < 1) { - expected_error = "Unable to read frame type."; - } else if (i < 2) { + if (i < 2) { expected_error = "Unable to read congestion feedback type."; } else if (i < 6) { expected_error = "Unable to read bitrate."; @@ -2007,11 +1902,10 @@ TEST_P(QuicFramerTest, CongestionFeedbackFrameFixRate) { packet, i + GetPacketHeaderSize(PACKET_8BYTE_GUID, !kIncludeVersion, PACKET_6BYTE_SEQUENCE_NUMBER, NOT_IN_FEC_GROUP), - expected_error, QUIC_INVALID_FRAME_DATA); + expected_error, QUIC_INVALID_CONGESTION_FEEDBACK_DATA); } } - TEST_P(QuicFramerTest, CongestionFeedbackFrameInvalidFeedback) { unsigned char packet[] = { // public flags (8 byte guid) @@ -2026,7 +1920,8 @@ TEST_P(QuicFramerTest, CongestionFeedbackFrameInvalidFeedback) { 0x00, // frame type (congestion feedback frame) - 0x03, + static_cast<unsigned char>( + GetParam() < QUIC_VERSION_10 ? 0x03 : 0x20), // congestion feedback type (invalid) 0x03, }; @@ -2034,7 +1929,7 @@ TEST_P(QuicFramerTest, CongestionFeedbackFrameInvalidFeedback) { QuicEncryptedPacket encrypted(AsChars(packet), arraysize(packet), false); EXPECT_FALSE(framer_.ProcessPacket(encrypted)); EXPECT_TRUE(CheckDecryption(encrypted, !kIncludeVersion)); - EXPECT_EQ(QUIC_INVALID_FRAME_DATA, framer_.error()); + EXPECT_EQ(QUIC_INVALID_CONGESTION_FEEDBACK_DATA, framer_.error()); } TEST_P(QuicFramerTest, RstStreamFrame) { @@ -2051,7 +1946,8 @@ TEST_P(QuicFramerTest, RstStreamFrame) { 0x00, // frame type (rst stream frame) - static_cast<unsigned char>(0x27), + static_cast<unsigned char>( + GetParam() < QUIC_VERSION_10 ? 0x27 : 0x01), // stream id 0x04, 0x03, 0x02, 0x01, // error code @@ -2078,7 +1974,7 @@ TEST_P(QuicFramerTest, RstStreamFrame) { EXPECT_EQ("because I can", visitor_.rst_stream_frame_.error_details); // Now test framing boundaries - for (size_t i = 2; i < 24; ++i) { + for (size_t i = kQuicFrameTypeSize; i < 24; ++i) { string expected_error; if (i < kQuicFrameTypeSize + kQuicMaxStreamIdSize) { expected_error = "Unable to read stream_id."; @@ -2110,7 +2006,8 @@ TEST_P(QuicFramerTest, ConnectionCloseFrame) { 0x00, // frame type (connection close frame) - static_cast<unsigned char>(0x2F), + static_cast<unsigned char>( + GetParam() < QUIC_VERSION_10 ? 0x2F : 0x02), // error code 0x11, 0x00, 0x00, 0x00, @@ -2197,7 +2094,8 @@ TEST_P(QuicFramerTest, GoAwayFrame) { 0x00, // frame type (go away frame) - static_cast<unsigned char>(0x37), + static_cast<unsigned char>( + GetParam() < QUIC_VERSION_10 ? 0x37 : 0x03), // error code 0x09, 0x00, 0x00, 0x00, // stream id @@ -2297,9 +2195,6 @@ TEST_P(QuicFramerTest, PublicResetPacket) { } TEST_P(QuicFramerTest, VersionNegotiationPacket) { - // Set a specific version. - framer_.set_version(QUIC_VERSION_7); - unsigned char packet[] = { // public flags (version, 8 byte guid) 0x3D, @@ -2307,7 +2202,7 @@ TEST_P(QuicFramerTest, VersionNegotiationPacket) { 0x10, 0x32, 0x54, 0x76, 0x98, 0xBA, 0xDC, 0xFE, // version tag - 'Q', '0', '0', '7', + 'Q', '0', GetQuicVersionDigitTens(), GetQuicVersionDigitOnes(), 'Q', '2', '.', '0', }; @@ -2318,8 +2213,7 @@ TEST_P(QuicFramerTest, VersionNegotiationPacket) { ASSERT_EQ(QUIC_NO_ERROR, framer_.error()); ASSERT_TRUE(visitor_.version_negotiation_packet_.get()); EXPECT_EQ(2u, visitor_.version_negotiation_packet_->versions.size()); - EXPECT_EQ(QUIC_VERSION_7, - visitor_.version_negotiation_packet_->versions[0]); + EXPECT_EQ(GetParam(), visitor_.version_negotiation_packet_->versions[0]); for (size_t i = 0; i <= kPublicFlagsSize + PACKET_8BYTE_GUID; ++i) { string expected_error; @@ -2401,7 +2295,9 @@ TEST_P(QuicFramerTest, BuildPaddingFramePacket) { 0x00, // frame type (padding frame) - static_cast<unsigned char>(0x07), + static_cast<unsigned char>( + GetParam() < QUIC_VERSION_10 ? 0x07 : 0x00), + 0x00, 0x00, 0x00, 0x00 }; uint64 header_size = @@ -2446,7 +2342,9 @@ TEST_P(QuicFramerTest, Build4ByteSequenceNumberPaddingFramePacket) { 0x00, // frame type (padding frame) - static_cast<unsigned char>(0x07), + static_cast<unsigned char>( + GetParam() < QUIC_VERSION_10 ? 0x07 : 0x00), + 0x00, 0x00, 0x00, 0x00 }; uint64 header_size = @@ -2491,7 +2389,9 @@ TEST_P(QuicFramerTest, Build2ByteSequenceNumberPaddingFramePacket) { 0x00, // frame type (padding frame) - static_cast<unsigned char>(0x07), + static_cast<unsigned char>( + GetParam() < QUIC_VERSION_10 ? 0x07 : 0x00), + 0x00, 0x00, 0x00, 0x00 }; uint64 header_size = @@ -2536,7 +2436,9 @@ TEST_P(QuicFramerTest, Build1ByteSequenceNumberPaddingFramePacket) { 0x00, // frame type (padding frame) - static_cast<unsigned char>(0x07), + static_cast<unsigned char>( + GetParam() < QUIC_VERSION_10 ? 0x07 : 0x00), + 0x00, 0x00, 0x00, 0x00 }; uint64 header_size = @@ -2554,9 +2456,6 @@ TEST_P(QuicFramerTest, Build1ByteSequenceNumberPaddingFramePacket) { } TEST_P(QuicFramerTest, BuildStreamFramePacket) { - // Set a specific version. - framer_.set_version(QUIC_VERSION_7); - QuicPacketHeader header; header.public_header.guid = GG_UINT64_C(0xFEDCBA9876543210); header.public_header.reset_flag = false; @@ -2588,7 +2487,8 @@ TEST_P(QuicFramerTest, BuildStreamFramePacket) { 0x01, // frame type (stream frame with fin and no length) - 0xBE, + static_cast<unsigned char>( + GetParam() < QUIC_VERSION_10 ? 0xBE : 0xDF), // stream id 0x04, 0x03, 0x02, 0x01, // offset @@ -2628,8 +2528,6 @@ TEST_P(QuicFramerTest, BuildStreamFramePacketWithVersionFlag) { QuicFrames frames; frames.push_back(QuicFrame(&stream_frame)); - // Set a specific version. - framer_.set_version(QUIC_VERSION_7); unsigned char packet[] = { // public flags (version, 8 byte guid) 0x3D, @@ -2637,7 +2535,7 @@ TEST_P(QuicFramerTest, BuildStreamFramePacketWithVersionFlag) { 0x10, 0x32, 0x54, 0x76, 0x98, 0xBA, 0xDC, 0xFE, // version tag - 'Q', '0', '0', '7', + 'Q', '0', GetQuicVersionDigitTens(), GetQuicVersionDigitOnes(), // packet sequence number 0xBC, 0x9A, 0x78, 0x56, 0x34, 0x12, @@ -2645,7 +2543,8 @@ TEST_P(QuicFramerTest, BuildStreamFramePacketWithVersionFlag) { 0x01, // frame type (stream frame with fin and no length) - 0xBE, + static_cast<unsigned char>( + GetParam() < QUIC_VERSION_10 ? 0xBE : 0xDF), // stream id 0x04, 0x03, 0x02, 0x01, // offset @@ -2680,11 +2579,11 @@ TEST_P(QuicFramerTest, BuildVersionNegotiationPacket) { 0x10, 0x32, 0x54, 0x76, 0x98, 0xBA, 0xDC, 0xFE, // version tag - 'Q', '0', '0', '7', + 'Q', '0', GetQuicVersionDigitTens(), GetQuicVersionDigitOnes(), }; QuicVersionVector versions; - versions.push_back(QUIC_VERSION_7); + versions.push_back(GetParam()); scoped_ptr<QuicEncryptedPacket> data( framer_.BuildVersionNegotiationPacket(header, versions)); @@ -2728,7 +2627,8 @@ TEST_P(QuicFramerTest, BuildAckFramePacket) { 0x01, // frame type (ack frame) - static_cast<unsigned char>(0x01), + static_cast<unsigned char>( + GetParam() < QUIC_VERSION_10 ? 0x01 : 0x40), // entropy hash of sent packets till least awaiting - 1. 0x14, // least packet sequence number awaiting an ack @@ -2788,7 +2688,8 @@ TEST_P(QuicFramerTest, BuildCongestionFeedbackFramePacketTCP) { 0x00, // frame type (congestion feedback frame) - 0x03, + static_cast<unsigned char>( + GetParam() < QUIC_VERSION_10 ? 0x03 : 0x20), // congestion feedback type (TCP) 0x00, // accumulated number of lost packets @@ -2847,7 +2748,8 @@ TEST_P(QuicFramerTest, BuildCongestionFeedbackFramePacketInterArrival) { 0x00, // frame type (congestion feedback frame) - 0x03, + static_cast<unsigned char>( + GetParam() < QUIC_VERSION_10 ? 0x03 : 0x20), // congestion feedback type (inter arrival) 0x01, // accumulated_number_of_lost_packets @@ -2910,7 +2812,8 @@ TEST_P(QuicFramerTest, BuildCongestionFeedbackFramePacketFixRate) { 0x00, // frame type (congestion feedback frame) - 0x03, + static_cast<unsigned char>( + GetParam() < QUIC_VERSION_10 ? 0x03 : 0x20), // congestion feedback type (fix rate) 0x02, // bitrate_in_bytes_per_second; @@ -2976,7 +2879,8 @@ TEST_P(QuicFramerTest, BuildRstFramePacket) { 0x00, // frame type (rst stream frame) - static_cast<unsigned char>(0x27), + static_cast<unsigned char>( + GetParam() < QUIC_VERSION_10 ? 0x27 : 0x01), // stream id 0x04, 0x03, 0x02, 0x01, // error code @@ -3039,7 +2943,8 @@ TEST_P(QuicFramerTest, BuildCloseFramePacket) { 0x01, // frame type (connection close frame) - static_cast<unsigned char>(0x2F), + static_cast<unsigned char>( + GetParam() < QUIC_VERSION_10 ? 0x2F : 0x02), // error code 0x08, 0x07, 0x06, 0x05, // error details length @@ -3110,7 +3015,8 @@ TEST_P(QuicFramerTest, BuildGoAwayPacket) { 0x01, // frame type (go away frame) - static_cast<unsigned char>(0x37), + static_cast<unsigned char>( + GetParam() < QUIC_VERSION_10 ? 0x37 : 0x03), // error code 0x08, 0x07, 0x06, 0x05, // stream id @@ -3438,9 +3344,6 @@ TEST_P(QuicFramerTest, CleanTruncation) { } TEST_P(QuicFramerTest, EntropyFlagTest) { - // Set a specific version. - framer_.set_version(QUIC_VERSION_7); - unsigned char packet[] = { // public flags (8 byte guid) 0x3C, @@ -3454,7 +3357,8 @@ TEST_P(QuicFramerTest, EntropyFlagTest) { 0x01, // frame type (stream frame with fin and no length) - 0xBE, + static_cast<unsigned char>( + GetParam() < QUIC_VERSION_10 ? 0xBE : 0xDF), // stream id 0x04, 0x03, 0x02, 0x01, // offset @@ -3476,9 +3380,6 @@ TEST_P(QuicFramerTest, EntropyFlagTest) { }; TEST_P(QuicFramerTest, FecEntropyTest) { - // Set a specific version. - framer_.set_version(QUIC_VERSION_7); - unsigned char packet[] = { // public flags (8 byte guid) 0x3C, @@ -3494,7 +3395,8 @@ TEST_P(QuicFramerTest, FecEntropyTest) { 0xFF, // frame type (stream frame with fin and no length) - 0xBE, + static_cast<unsigned char>( + GetParam() < QUIC_VERSION_10 ? 0xBE : 0xDF), // stream id 0x04, 0x03, 0x02, 0x01, // offset @@ -3516,9 +3418,6 @@ TEST_P(QuicFramerTest, FecEntropyTest) { }; TEST_P(QuicFramerTest, StopPacketProcessing) { - // Set a specific version. - framer_.set_version(QUIC_VERSION_7); - unsigned char packet[] = { // public flags (8 byte guid) 0x3C, @@ -3532,7 +3431,8 @@ TEST_P(QuicFramerTest, StopPacketProcessing) { 0x01, // frame type (stream frame with fin) - 0xFE, + static_cast<unsigned char>( + GetParam() < QUIC_VERSION_10 ? 0xFE : 0xFF), // stream id 0x04, 0x03, 0x02, 0x01, // offset @@ -3546,7 +3446,8 @@ TEST_P(QuicFramerTest, StopPacketProcessing) { 'r', 'l', 'd', '!', // frame type (ack frame) - 0x02, + static_cast<unsigned char>( + GetParam() < QUIC_VERSION_10 ? 0x01 : 0x40), // entropy hash of sent packets till least awaiting - 1. 0x14, // least packet sequence number awaiting an ack @@ -3591,7 +3492,8 @@ TEST_P(QuicFramerTest, ConnectionCloseWithInvalidAck) { 0x00, // frame type (connection close frame) - static_cast<unsigned char>(0x2F), + static_cast<unsigned char>( + GetParam() < QUIC_VERSION_10 ? 0x2F : 0x02), // error code 0x11, 0x00, 0x00, 0x00, // error details length diff --git a/chromium/net/quic/quic_http_stream.cc b/chromium/net/quic/quic_http_stream.cc index 94d1a2e56ba..0122c64b1ad 100644 --- a/chromium/net/quic/quic_http_stream.cc +++ b/chromium/net/quic/quic_http_stream.cc @@ -11,6 +11,7 @@ #include "net/http/http_response_headers.h" #include "net/http/http_util.h" #include "net/quic/quic_client_session.h" +#include "net/quic/quic_http_utils.h" #include "net/quic/quic_reliable_client_stream.h" #include "net/quic/quic_utils.h" #include "net/socket/next_proto.h" @@ -29,6 +30,7 @@ QuicHttpStream::QuicHttpStream(const base::WeakPtr<QuicClientSession> session) stream_(NULL), request_info_(NULL), request_body_stream_(NULL), + priority_(MINIMUM_PRIORITY), response_info_(NULL), response_status_(OK), response_headers_received_(false), @@ -52,6 +54,7 @@ int QuicHttpStream::InitializeStream(const HttpRequestInfo* request_info, stream_net_log_ = stream_net_log; request_info_ = request_info; + priority_ = priority; int rv = stream_request_.StartRequest( session_, &stream_, base::Bind(&QuicHttpStream::OnStreamReady, @@ -82,11 +85,18 @@ int QuicHttpStream::SendRequest(const HttpRequestHeaders& request_headers, CHECK(!callback.is_null()); CHECK(response); + QuicPriority priority = ConvertRequestPriorityToQuicPriority(priority_); + stream_->set_priority(priority); // Store the serialized request headers. SpdyHeaderBlock headers; CreateSpdyHeadersFromHttpRequest(*request_info_, request_headers, &headers, 3, /*direct=*/true); - request_ = stream_->compressor()->CompressHeaders(headers); + if (session_->connection()->version() < QUIC_VERSION_9) { + request_ = stream_->compressor()->CompressHeaders(headers); + } else { + request_ = stream_->compressor()->CompressHeadersWithPriority(priority, + headers); + } // Log the actual request with the URL Request's net log. stream_net_log_.AddEvent( NetLog::TYPE_HTTP_TRANSACTION_SPDY_SEND_REQUEST_HEADERS, @@ -201,7 +211,9 @@ void QuicHttpStream::Close(bool not_reusable) { // Note: the not_reusable flag has no meaning for SPDY streams. if (stream_) { stream_->SetDelegate(NULL); - stream_->Close(QUIC_STREAM_NO_ERROR); + // TODO(rch): use new CANCELLED error code here once quic 11 + // is everywhere. + stream_->Close(QUIC_ERROR_PROCESSING_STREAM); stream_ = NULL; } } @@ -257,6 +269,10 @@ void QuicHttpStream::Drain(HttpNetworkSession* session) { delete this; } +void QuicHttpStream::SetPriority(RequestPriority priority) { + priority_ = priority; +} + int QuicHttpStream::OnSendData() { // TODO(rch): Change QUIC IO to provide notifications to the streams. NOTREACHED(); @@ -275,7 +291,10 @@ int QuicHttpStream::OnDataReceived(const char* data, int length) { if (!response_headers_received_) { // Grow the read buffer if necessary. if (read_buf_->RemainingCapacity() < length) { - read_buf_->SetCapacity(read_buf_->capacity() + kHeaderBufInitialSize); + size_t additional_capacity = length - read_buf_->RemainingCapacity(); + if (additional_capacity < kHeaderBufInitialSize) + additional_capacity = kHeaderBufInitialSize; + read_buf_->SetCapacity(read_buf_->capacity() + additional_capacity); } memcpy(read_buf_->data(), data, length); read_buf_->set_offset(read_buf_->offset() + length); @@ -323,6 +342,10 @@ void QuicHttpStream::OnError(int error) { DoCallback(response_status_); } +bool QuicHttpStream::HasSendHeadersComplete() { + return next_state_ > STATE_SEND_HEADERS_COMPLETE; +} + void QuicHttpStream::OnIOComplete(int rv) { rv = DoLoop(rv); @@ -386,8 +409,9 @@ int QuicHttpStream::DoSendHeaders() { bool has_upload_data = request_body_stream_ != NULL; next_state_ = STATE_SEND_HEADERS_COMPLETE; - QuicConsumedData rv = stream_->WriteData(request_, !has_upload_data); - return rv.bytes_consumed; + return stream_->WriteStreamData( + request_, !has_upload_data, + base::Bind(&QuicHttpStream::OnIOComplete, weak_factory_.GetWeakPtr())); } int QuicHttpStream::DoSendHeadersComplete(int rv) { @@ -432,18 +456,14 @@ int QuicHttpStream::DoSendBody() { const bool eof = request_body_stream_->IsEOF(); int len = request_body_buf_->BytesRemaining(); if (len > 0 || eof) { - base::StringPiece data(request_body_buf_->data(), len); - QuicConsumedData rv = stream_->WriteData(data, eof); - request_body_buf_->DidConsume(rv.bytes_consumed); - if (eof) { - next_state_ = STATE_OPEN; - return OK; - } next_state_ = STATE_SEND_BODY_COMPLETE; - return rv.bytes_consumed; + base::StringPiece data(request_body_buf_->data(), len); + return stream_->WriteStreamData( + data, eof, + base::Bind(&QuicHttpStream::OnIOComplete, weak_factory_.GetWeakPtr())); } - next_state_ = STATE_SEND_BODY_COMPLETE; + next_state_ = STATE_OPEN; return OK; } @@ -451,7 +471,14 @@ int QuicHttpStream::DoSendBodyComplete(int rv) { if (rv < 0) return rv; - next_state_ = STATE_READ_REQUEST_BODY; + request_body_buf_->DidConsume(request_body_buf_->BytesRemaining()); + + if (!request_body_stream_->IsEOF()) { + next_state_ = STATE_READ_REQUEST_BODY; + return OK; + } + + next_state_ = STATE_OPEN; return OK; } diff --git a/chromium/net/quic/quic_http_stream.h b/chromium/net/quic/quic_http_stream.h index 85dc5537480..71fb5153352 100644 --- a/chromium/net/quic/quic_http_stream.h +++ b/chromium/net/quic/quic_http_stream.h @@ -15,6 +15,10 @@ namespace net { +namespace test { +class QuicHttpStreamPeer; +} // namespace test + // The QuicHttpStream is a QUIC-specific HttpStream subclass. It holds a // non-owning pointer to a QuicReliableClientStream which it uses to // send and receive data. @@ -54,6 +58,7 @@ class NET_EXPORT_PRIVATE QuicHttpStream : SSLCertRequestInfo* cert_request_info) OVERRIDE; virtual bool IsSpdyHttpStream() const OVERRIDE; virtual void Drain(HttpNetworkSession* session) OVERRIDE; + virtual void SetPriority(RequestPriority priority) OVERRIDE; // QuicReliableClientStream::Delegate implementation virtual int OnSendData() OVERRIDE; @@ -61,8 +66,11 @@ class NET_EXPORT_PRIVATE QuicHttpStream : virtual int OnDataReceived(const char* data, int length) OVERRIDE; virtual void OnClose(QuicErrorCode error) OVERRIDE; virtual void OnError(int error) OVERRIDE; + virtual bool HasSendHeadersComplete() OVERRIDE; private: + friend class test::QuicHttpStreamPeer; + enum State { STATE_NONE, STATE_SEND_HEADERS, @@ -105,6 +113,8 @@ class NET_EXPORT_PRIVATE QuicHttpStream : const HttpRequestInfo* request_info_; // The request body to send, if any, owned by the caller. UploadDataStream* request_body_stream_; + // The priority of the request. + RequestPriority priority_; // |response_info_| is the HTTP response data object which is filled in // when a the response headers are read. It is not owned by this stream. HttpResponseInfo* response_info_; diff --git a/chromium/net/quic/quic_http_stream_test.cc b/chromium/net/quic/quic_http_stream_test.cc index b3784168a70..7ccab6d8e1e 100644 --- a/chromium/net/quic/quic_http_stream_test.cc +++ b/chromium/net/quic/quic_http_stream_test.cc @@ -19,6 +19,9 @@ #include "net/quic/quic_client_session.h" #include "net/quic/quic_connection.h" #include "net/quic/quic_connection_helper.h" +#include "net/quic/quic_http_utils.h" +#include "net/quic/quic_reliable_client_stream.h" +#include "net/quic/spdy_utils.h" #include "net/quic/test_tools/mock_clock.h" #include "net/quic/test_tools/mock_crypto_client_stream_factory.h" #include "net/quic/test_tools/mock_random.h" @@ -30,10 +33,13 @@ #include "net/spdy/spdy_framer.h" #include "net/spdy/spdy_http_utils.h" #include "net/spdy/spdy_protocol.h" +#include "net/spdy/write_blocked_list.h" #include "testing/gmock/include/gmock/gmock.h" #include "testing/gtest/include/gtest/gtest.h" using testing::_; +using testing::AnyNumber; +using testing::Return; namespace net { namespace test { @@ -99,6 +105,14 @@ class AutoClosingStream : public QuicHttpStream { } // namespace +class QuicHttpStreamPeer { + public: + static QuicReliableClientStream* GetQuicReliableClientStream( + QuicHttpStream* stream) { + return stream->stream_; + } +}; + class QuicHttpStreamTest : public ::testing::TestWithParam<bool> { protected: const static bool kFin = true; @@ -124,12 +138,16 @@ class QuicHttpStreamTest : public ::testing::TestWithParam<bool> { CHECK(ParseIPLiteralToNumber("192.0.2.33", &ip)); peer_addr_ = IPEndPoint(ip, 443); self_addr_ = IPEndPoint(ip, 8435); + // TODO(rch): remove this. + QuicConnection::g_acks_do_not_instigate_acks = true; } ~QuicHttpStreamTest() { for (size_t i = 0; i < writes_.size(); i++) { delete writes_[i].packet; } + // TODO(rch): remove this. + QuicConnection::g_acks_do_not_instigate_acks = false; } // Adds a packet to the list of expected writes. @@ -168,10 +186,17 @@ class QuicHttpStreamTest : public ::testing::TestWithParam<bool> { runner_ = new TestTaskRunner(&clock_); send_algorithm_ = new MockSendAlgorithm(); receive_algorithm_ = new TestReceiveAlgorithm(NULL); + EXPECT_CALL(*receive_algorithm_, RecordIncomingPacket(_, _, _, _)). + Times(AnyNumber()); + EXPECT_CALL(*send_algorithm_, SentPacket(_, _, _, _, _)).Times(AnyNumber()); EXPECT_CALL(*send_algorithm_, RetransmissionDelay()).WillRepeatedly( - testing::Return(QuicTime::Delta::Zero())); + Return(QuicTime::Delta::Zero())); EXPECT_CALL(*send_algorithm_, TimeUntilSend(_, _, _, _)). - WillRepeatedly(testing::Return(QuicTime::Delta::Zero())); + WillRepeatedly(Return(QuicTime::Delta::Zero())); + EXPECT_CALL(*send_algorithm_, SmoothedRtt()).WillRepeatedly( + Return(QuicTime::Delta::Zero())); + EXPECT_CALL(*send_algorithm_, BandwidthEstimate()).WillRepeatedly( + Return(QuicBandwidth::Zero())); helper_ = new QuicConnectionHelper(runner_.get(), &clock_, &random_generator_, socket); connection_ = new TestQuicConnection(guid_, peer_addr_, helper_); @@ -179,10 +204,12 @@ class QuicHttpStreamTest : public ::testing::TestWithParam<bool> { connection_->SetSendAlgorithm(send_algorithm_); connection_->SetReceiveAlgorithm(receive_algorithm_); crypto_config_.SetDefaults(); - session_.reset(new QuicClientSession(connection_, socket, NULL, - &crypto_client_stream_factory_, - "www.google.com", DefaultQuicConfig(), - &crypto_config_, NULL)); + session_.reset( + new QuicClientSession(connection_, + scoped_ptr<DatagramClientSocket>(socket), NULL, + &crypto_client_stream_factory_, + "www.google.com", DefaultQuicConfig(), + &crypto_config_, NULL)); session_->GetCryptoStream()->CryptoConnect(); EXPECT_TRUE(session_->IsCryptoHandshakeConfirmed()); stream_.reset(use_closing_stream_ ? @@ -190,14 +217,16 @@ class QuicHttpStreamTest : public ::testing::TestWithParam<bool> { new QuicHttpStream(session_->GetWeakPtr())); } - void SetRequestString(const std::string& method, const std::string& path) { + void SetRequestString(const std::string& method, + const std::string& path, + RequestPriority priority) { SpdyHeaderBlock headers; headers[":method"] = method; headers[":host"] = "www.google.com"; headers[":path"] = path; headers[":scheme"] = "http"; headers[":version"] = "HTTP/1.1"; - request_data_ = SerializeHeaderBlock(headers); + request_data_ = SerializeHeaderBlock(headers, true, priority); } void SetResponseString(const std::string& status, const std::string& body) { @@ -205,15 +234,22 @@ class QuicHttpStreamTest : public ::testing::TestWithParam<bool> { headers[":status"] = status; headers[":version"] = "HTTP/1.1"; headers["content-type"] = "text/plain"; - response_data_ = SerializeHeaderBlock(headers) + body; + response_data_ = SerializeHeaderBlock(headers, false, DEFAULT_PRIORITY) + + body; } - std::string SerializeHeaderBlock(const SpdyHeaderBlock& headers) { + std::string SerializeHeaderBlock(const SpdyHeaderBlock& headers, + bool write_priority, + RequestPriority priority) { QuicSpdyCompressor compressor; + if (framer_.version() >= QUIC_VERSION_9 && write_priority) { + return compressor.CompressHeadersWithPriority( + ConvertRequestPriorityToQuicPriority(priority), headers); + } return compressor.CompressHeaders(headers); } - // Returns a newly created packet to send kData on stream 1. + // Returns a newly created packet to send kData on stream 3. QuicEncryptedPacket* ConstructDataPacket( QuicPacketSequenceNumber sequence_number, bool should_include_version, @@ -225,6 +261,14 @@ class QuicHttpStreamTest : public ::testing::TestWithParam<bool> { return ConstructPacket(header_, QuicFrame(&frame)); } + // Returns a newly created packet to RST_STREAM stream 3. + QuicEncryptedPacket* ConstructRstStreamPacket( + QuicPacketSequenceNumber sequence_number) { + InitializeHeader(sequence_number, false); + QuicRstStreamFrame frame(3, QUIC_ERROR_PROCESSING_STREAM); + return ConstructPacket(header_, QuicFrame(&frame)); + } + // Returns a newly created packet to send ack data. QuicEncryptedPacket* ConstructAckPacket( QuicPacketSequenceNumber sequence_number, @@ -277,6 +321,7 @@ class QuicHttpStreamTest : public ::testing::TestWithParam<bool> { header_.public_header.guid = guid_; header_.public_header.reset_flag = false; header_.public_header.version_flag = should_include_version; + header_.public_header.sequence_number_length = PACKET_1BYTE_SEQUENCE_NUMBER; header_.packet_sequence_number = sequence_number; header_.fec_group = 0; header_.entropy_flag = false; @@ -321,7 +366,7 @@ TEST_F(QuicHttpStreamTest, IsConnectionReusable) { } TEST_F(QuicHttpStreamTest, GetRequest) { - SetRequestString("GET", "/"); + SetRequestString("GET", "/", DEFAULT_PRIORITY); AddWrite(SYNCHRONOUS, ConstructDataPacket(1, true, kFin, 0, request_data_)); Initialize(); @@ -362,8 +407,56 @@ TEST_F(QuicHttpStreamTest, GetRequest) { EXPECT_TRUE(AtEof()); } +// Regression test for http://crbug.com/288128 +TEST_F(QuicHttpStreamTest, GetRequestLargeResponse) { + SetRequestString("GET", "/", DEFAULT_PRIORITY); + AddWrite(SYNCHRONOUS, ConstructDataPacket(1, true, kFin, 0, + request_data_)); + Initialize(); + + request_.method = "GET"; + request_.url = GURL("http://www.google.com/"); + + EXPECT_EQ(OK, stream_->InitializeStream(&request_, DEFAULT_PRIORITY, + net_log_, callback_.callback())); + EXPECT_EQ(OK, stream_->SendRequest(headers_, &response_, + callback_.callback())); + EXPECT_EQ(&response_, stream_->GetResponseInfo()); + + // Ack the request. + scoped_ptr<QuicEncryptedPacket> ack(ConstructAckPacket(1, 0, 0)); + ProcessPacket(*ack); + + EXPECT_EQ(ERR_IO_PENDING, + stream_->ReadResponseHeaders(callback_.callback())); + + SpdyHeaderBlock headers; + headers[":status"] = "200 OK"; + headers[":version"] = "HTTP/1.1"; + headers["content-type"] = "text/plain"; + headers["big6"] = std::string(10000, 'x'); // Lots of x's. + + std::string response = SpdyUtils::SerializeUncompressedHeaders(headers); + EXPECT_LT(4096u, response.length()); + stream_->OnDataReceived(response.data(), response.length()); + stream_->OnClose(QUIC_NO_ERROR); + + // Now that the headers have been processed, the callback will return. + EXPECT_EQ(OK, callback_.WaitForResult()); + ASSERT_TRUE(response_.headers.get()); + EXPECT_EQ(200, response_.headers->response_code()); + EXPECT_TRUE(response_.headers->HasHeaderValue("Content-Type", "text/plain")); + + // There is no body, so this should return immediately. + EXPECT_EQ(0, stream_->ReadResponseBody(read_buffer_.get(), + read_buffer_->size(), + callback_.callback())); + EXPECT_TRUE(stream_->IsResponseBodyComplete()); + EXPECT_TRUE(AtEof()); +} + TEST_F(QuicHttpStreamTest, GetRequestFullResponseInSinglePacket) { - SetRequestString("GET", "/"); + SetRequestString("GET", "/", DEFAULT_PRIORITY); AddWrite(SYNCHRONOUS, ConstructDataPacket(1, true, kFin, 0, request_data_)); Initialize(); @@ -405,7 +498,7 @@ TEST_F(QuicHttpStreamTest, GetRequestFullResponseInSinglePacket) { } TEST_F(QuicHttpStreamTest, SendPostRequest) { - SetRequestString("POST", "/"); + SetRequestString("POST", "/", DEFAULT_PRIORITY); AddWrite(SYNCHRONOUS, ConstructDataPacket(1, true, !kFin, 0, request_data_)); AddWrite(SYNCHRONOUS, ConstructDataPacket(2, true, kFin, request_data_.length(), @@ -462,7 +555,7 @@ TEST_F(QuicHttpStreamTest, SendPostRequest) { } TEST_F(QuicHttpStreamTest, SendChunkedPostRequest) { - SetRequestString("POST", "/"); + SetRequestString("POST", "/", DEFAULT_PRIORITY); size_t chunk_size = strlen(kUploadData); AddWrite(SYNCHRONOUS, ConstructDataPacket(1, true, !kFin, 0, request_data_)); AddWrite(SYNCHRONOUS, ConstructDataPacket(2, true, !kFin, @@ -524,8 +617,9 @@ TEST_F(QuicHttpStreamTest, SendChunkedPostRequest) { } TEST_F(QuicHttpStreamTest, DestroyedEarly) { - SetRequestString("GET", "/"); + SetRequestString("GET", "/", DEFAULT_PRIORITY); AddWrite(SYNCHRONOUS, ConstructDataPacket(1, true, kFin, 0, request_data_)); + AddWrite(SYNCHRONOUS, ConstructRstStreamPacket(2)); use_closing_stream_ = true; Initialize(); @@ -535,9 +629,54 @@ TEST_F(QuicHttpStreamTest, DestroyedEarly) { EXPECT_EQ(OK, stream_->InitializeStream(&request_, DEFAULT_PRIORITY, net_log_, callback_.callback())); EXPECT_EQ(OK, stream_->SendRequest(headers_, &response_, - callback_.callback())); + callback_.callback())); + EXPECT_EQ(&response_, stream_->GetResponseInfo()); + + // Ack the request. + scoped_ptr<QuicEncryptedPacket> ack(ConstructAckPacket(1, 0, 0)); + ProcessPacket(*ack); + EXPECT_EQ(ERR_IO_PENDING, + stream_->ReadResponseHeaders(callback_.callback())); + + // Send the response with a body. + SetResponseString("404 OK", "hello world!"); + scoped_ptr<QuicEncryptedPacket> resp( + ConstructDataPacket(2, false, kFin, 0, response_data_)); + + // In the course of processing this packet, the QuicHttpStream close itself. + ProcessPacket(*resp); + + EXPECT_TRUE(AtEof()); +} + +TEST_F(QuicHttpStreamTest, Priority) { + SetRequestString("GET", "/", MEDIUM); + AddWrite(SYNCHRONOUS, ConstructDataPacket(1, true, kFin, 0, request_data_)); + AddWrite(SYNCHRONOUS, ConstructRstStreamPacket(2)); + use_closing_stream_ = true; + Initialize(); + + request_.method = "GET"; + request_.url = GURL("http://www.google.com/"); + + EXPECT_EQ(OK, stream_->InitializeStream(&request_, MEDIUM, + net_log_, callback_.callback())); + + // Check that priority is highest. + QuicReliableClientStream* reliable_stream = + QuicHttpStreamPeer::GetQuicReliableClientStream(stream_.get()); + DCHECK(reliable_stream); + DCHECK_EQ(static_cast<QuicPriority>(kHighestPriority), + reliable_stream->EffectivePriority()); + + EXPECT_EQ(OK, stream_->SendRequest(headers_, &response_, + callback_.callback())); EXPECT_EQ(&response_, stream_->GetResponseInfo()); + // Check that priority has now dropped back to MEDIUM. + DCHECK_EQ(MEDIUM, ConvertQuicPriorityToRequestPriority( + reliable_stream->EffectivePriority())); + // Ack the request. scoped_ptr<QuicEncryptedPacket> ack(ConstructAckPacket(1, 0, 0)); ProcessPacket(*ack); @@ -555,6 +694,35 @@ TEST_F(QuicHttpStreamTest, DestroyedEarly) { EXPECT_TRUE(AtEof()); } +// Regression test for http://crbug.com/294870 +TEST_F(QuicHttpStreamTest, CheckPriorityWithNoDelegate) { + SetRequestString("GET", "/", MEDIUM); + use_closing_stream_ = true; + Initialize(); + + request_.method = "GET"; + request_.url = GURL("http://www.google.com/"); + + EXPECT_EQ(OK, stream_->InitializeStream(&request_, MEDIUM, + net_log_, callback_.callback())); + + // Check that priority is highest. + QuicReliableClientStream* reliable_stream = + QuicHttpStreamPeer::GetQuicReliableClientStream(stream_.get()); + DCHECK(reliable_stream); + QuicReliableClientStream::Delegate* delegate = reliable_stream->GetDelegate(); + DCHECK(delegate); + DCHECK_EQ(static_cast<QuicPriority>(kHighestPriority), + reliable_stream->EffectivePriority()); + + // Set Delegate to NULL and make sure EffectivePriority returns highest + // priority. + reliable_stream->SetDelegate(NULL); + DCHECK_EQ(static_cast<QuicPriority>(kHighestPriority), + reliable_stream->EffectivePriority()); + reliable_stream->SetDelegate(delegate); +} + } // namespace test } // namespace net diff --git a/chromium/net/quic/quic_http_utils.cc b/chromium/net/quic/quic_http_utils.cc new file mode 100644 index 00000000000..4a486268547 --- /dev/null +++ b/chromium/net/quic/quic_http_utils.cc @@ -0,0 +1,23 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/quic/quic_http_utils.h" + +namespace net { + +QuicPriority ConvertRequestPriorityToQuicPriority( + const RequestPriority priority) { + DCHECK_GE(priority, MINIMUM_PRIORITY); + DCHECK_LT(priority, NUM_PRIORITIES); + return static_cast<QuicPriority>(HIGHEST - priority); +} + +NET_EXPORT_PRIVATE RequestPriority ConvertQuicPriorityToRequestPriority( + QuicPriority priority) { + // Handle invalid values gracefully. + return (priority >= 5) ? + IDLE : static_cast<RequestPriority>(HIGHEST - priority); +} + +} // namespace net diff --git a/chromium/net/quic/quic_http_utils.h b/chromium/net/quic/quic_http_utils.h new file mode 100644 index 00000000000..c7e031ae605 --- /dev/null +++ b/chromium/net/quic/quic_http_utils.h @@ -0,0 +1,22 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_QUIC_QUIC_HTTP_UTILS_H_ +#define NET_QUIC_QUIC_HTTP_UTILS_H_ + +#include "net/base/net_export.h" +#include "net/base/request_priority.h" +#include "net/quic/quic_protocol.h" + +namespace net { + +NET_EXPORT_PRIVATE QuicPriority ConvertRequestPriorityToQuicPriority( + RequestPriority priority); + +NET_EXPORT_PRIVATE RequestPriority ConvertQuicPriorityToRequestPriority( + QuicPriority priority); + +} // namespace net + +#endif // NET_QUIC_QUIC_HTTP_UTILS_H_ diff --git a/chromium/net/quic/quic_http_utils_test.cc b/chromium/net/quic/quic_http_utils_test.cc new file mode 100644 index 00000000000..93b62e2e9d7 --- /dev/null +++ b/chromium/net/quic/quic_http_utils_test.cc @@ -0,0 +1,35 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/quic/quic_http_utils.h" + +#include "testing/gtest/include/gtest/gtest.h" + +namespace net { +namespace test { + +TEST(QuicHttpUtilsTest, ConvertRequestPriorityToQuicPriority) { + EXPECT_EQ(0u, ConvertRequestPriorityToQuicPriority(HIGHEST)); + EXPECT_EQ(1u, ConvertRequestPriorityToQuicPriority(MEDIUM)); + EXPECT_EQ(2u, ConvertRequestPriorityToQuicPriority(LOW)); + EXPECT_EQ(3u, ConvertRequestPriorityToQuicPriority(LOWEST)); + EXPECT_EQ(4u, ConvertRequestPriorityToQuicPriority(IDLE)); +} + +TEST(QuicHttpUtilsTest, ConvertQuicPriorityToRequestPriority) { + EXPECT_EQ(HIGHEST, ConvertQuicPriorityToRequestPriority(0)); + EXPECT_EQ(MEDIUM, ConvertQuicPriorityToRequestPriority(1)); + EXPECT_EQ(LOW, ConvertQuicPriorityToRequestPriority(2)); + EXPECT_EQ(LOWEST, ConvertQuicPriorityToRequestPriority(3)); + EXPECT_EQ(IDLE, ConvertQuicPriorityToRequestPriority(4)); + // These are invalid values, but we should still handle them + // gracefully. TODO(rtenneti): should we test for all possible values of + // uint32? + for (int i = 5; i < kuint8max; ++i) { + EXPECT_EQ(IDLE, ConvertQuicPriorityToRequestPriority(i)); + } +} + +} // namespace test +} // namespace net diff --git a/chromium/net/quic/quic_network_transaction_unittest.cc b/chromium/net/quic/quic_network_transaction_unittest.cc index 0722b58121f..a6cbff1b81a 100644 --- a/chromium/net/quic/quic_network_transaction_unittest.cc +++ b/chromium/net/quic/quic_network_transaction_unittest.cc @@ -25,6 +25,7 @@ #include "net/quic/crypto/quic_decrypter.h" #include "net/quic/crypto/quic_encrypter.h" #include "net/quic/quic_framer.h" +#include "net/quic/quic_http_utils.h" #include "net/quic/test_tools/crypto_test_utils.h" #include "net/quic/test_tools/mock_clock.h" #include "net/quic/test_tools/mock_crypto_client_stream_factory.h" @@ -92,6 +93,7 @@ class QuicNetworkTransactionTest : public PlatformTest { header.public_header.guid = 0xDEADBEEF; header.public_header.reset_flag = false; header.public_header.version_flag = false; + header.public_header.sequence_number_length = PACKET_1BYTE_SEQUENCE_NUMBER; header.packet_sequence_number = num; header.entropy_flag = false; header.fec_flag = false; @@ -108,6 +110,7 @@ class QuicNetworkTransactionTest : public PlatformTest { header.public_header.guid = 0xDEADBEEF; header.public_header.reset_flag = false; header.public_header.version_flag = false; + header.public_header.sequence_number_length = PACKET_1BYTE_SEQUENCE_NUMBER; header.packet_sequence_number = num; header.entropy_flag = false; header.fec_flag = false; @@ -129,6 +132,7 @@ class QuicNetworkTransactionTest : public PlatformTest { header.public_header.guid = 0xDEADBEEF; header.public_header.reset_flag = false; header.public_header.version_flag = false; + header.public_header.sequence_number_length = PACKET_1BYTE_SEQUENCE_NUMBER; header.packet_sequence_number = 2; header.entropy_flag = false; header.fec_flag = false; @@ -174,6 +178,10 @@ class QuicNetworkTransactionTest : public PlatformTest { std::string SerializeHeaderBlock(const SpdyHeaderBlock& headers) { QuicSpdyCompressor compressor; + if (QuicVersionMax() >= QUIC_VERSION_9) { + return compressor.CompressHeadersWithPriority( + ConvertRequestPriorityToQuicPriority(DEFAULT_PRIORITY), headers); + } return compressor.CompressHeaders(headers); } @@ -207,6 +215,7 @@ class QuicNetworkTransactionTest : public PlatformTest { header_.public_header.guid = random_generator_.RandUint64(); header_.public_header.reset_flag = false; header_.public_header.version_flag = should_include_version; + header_.public_header.sequence_number_length = PACKET_1BYTE_SEQUENCE_NUMBER; header_.packet_sequence_number = sequence_number; header_.fec_group = 0; header_.entropy_flag = false; @@ -232,6 +241,7 @@ class QuicNetworkTransactionTest : public PlatformTest { params_.http_server_properties = http_server_properties.GetWeakPtr(); session_ = new HttpNetworkSession(params_); + session_->quic_stream_factory()->set_require_confirmation(false); } void CheckWasQuicResponse(const scoped_ptr<HttpNetworkTransaction>& trans) { @@ -679,7 +689,11 @@ TEST_F(QuicNetworkTransactionTest, ZeroRTTWithNoHttpRace) { host_resolver_.rules()->AddIPLiteralRule("www.google.com", "192.168.0.1", ""); HostResolver::RequestInfo info(HostPortPair("www.google.com", 80)); AddressList address; - host_resolver_.Resolve(info, &address, CompletionCallback(), NULL, + host_resolver_.Resolve(info, + DEFAULT_PRIORITY, + &address, + CompletionCallback(), + NULL, net_log_.bound()); CreateSession(); @@ -687,6 +701,65 @@ TEST_F(QuicNetworkTransactionTest, ZeroRTTWithNoHttpRace) { SendRequestAndExpectQuicResponse("hello!"); } +TEST_F(QuicNetworkTransactionTest, ZeroRTTWithConfirmationRequired) { + HttpStreamFactory::EnableNpnSpdy(); // Enables QUIC too. + + scoped_ptr<QuicEncryptedPacket> req( + ConstructDataPacket(1, 3, true, true, 0, + GetRequestString("GET", "http", "/"))); + scoped_ptr<QuicEncryptedPacket> ack(ConstructAckPacket(1, 0)); + + MockWrite quic_writes[] = { + MockWrite(SYNCHRONOUS, req->data(), req->length()), + MockWrite(SYNCHRONOUS, ack->data(), ack->length()), + }; + + scoped_ptr<QuicEncryptedPacket> resp( + ConstructDataPacket( + 1, 3, false, true, 0, GetResponseString("200 OK", "hello!"))); + MockRead quic_reads[] = { + MockRead(SYNCHRONOUS, resp->data(), resp->length()), + MockRead(ASYNC, OK), // EOF + }; + + DelayedSocketData quic_data( + 1, // wait for one write to finish before reading. + quic_reads, arraysize(quic_reads), + quic_writes, arraysize(quic_writes)); + + socket_factory_.AddSocketDataProvider(&quic_data); + + // The non-alternate protocol job needs to hang in order to guarantee that + // the alternate-protocol job will "win". + AddHangingNonAlternateProtocolSocketData(); + + // In order for a new QUIC session to be established via alternate-protocol + // without racing an HTTP connection, we need the host resolution to happen + // synchronously. Of course, even though QUIC *could* perform a 0-RTT + // connection to the the server, in this test we require confirmation + // before encrypting so the HTTP job will still start. + host_resolver_.set_synchronous_mode(true); + host_resolver_.rules()->AddIPLiteralRule("www.google.com", "192.168.0.1", ""); + HostResolver::RequestInfo info(HostPortPair("www.google.com", 80)); + AddressList address; + host_resolver_.Resolve(info, DEFAULT_PRIORITY, &address, + CompletionCallback(), NULL, net_log_.bound()); + + CreateSession(); + session_->quic_stream_factory()->set_require_confirmation(true); + AddQuicAlternateProtocolMapping(MockCryptoClientStream::ZERO_RTT); + + scoped_ptr<HttpNetworkTransaction> trans( + new HttpNetworkTransaction(DEFAULT_PRIORITY, session_.get())); + TestCompletionCallback callback; + int rv = trans->Start(&request_, callback.callback(), net_log_.bound()); + EXPECT_EQ(ERR_IO_PENDING, rv); + + crypto_client_stream_factory_.last_stream()->SendOnCryptoHandshakeEvent( + QuicSession::HANDSHAKE_CONFIRMED); + EXPECT_EQ(OK, callback.WaitForResult()); +} + TEST_F(QuicNetworkTransactionTest, BrokenAlternateProtocol) { HttpStreamFactory::EnableNpnSpdy(); // Enables QUIC too. @@ -748,5 +821,32 @@ TEST_F(QuicNetworkTransactionTest, BrokenAlternateProtocolReadError) { ExpectBrokenAlternateProtocolMapping(); } +TEST_F(QuicNetworkTransactionTest, FailedZeroRttBrokenAlternateProtocol) { + HttpStreamFactory::EnableNpnSpdy(); // Enables QUIC too. + + // Alternate-protocol job + MockRead quic_reads[] = { + MockRead(ASYNC, ERR_SOCKET_NOT_CONNECTED), + }; + StaticSocketDataProvider quic_data(quic_reads, arraysize(quic_reads), + NULL, 0); + socket_factory_.AddSocketDataProvider(&quic_data); + + AddHangingNonAlternateProtocolSocketData(); + + CreateSession(); + + AddQuicAlternateProtocolMapping(MockCryptoClientStream::ZERO_RTT); + + scoped_ptr<HttpNetworkTransaction> trans( + new HttpNetworkTransaction(DEFAULT_PRIORITY, session_.get())); + TestCompletionCallback callback; + int rv = trans->Start(&request_, callback.callback(), net_log_.bound()); + EXPECT_EQ(ERR_IO_PENDING, rv); + EXPECT_EQ(ERR_CONNECTION_CLOSED, callback.WaitForResult()); + + ExpectBrokenAlternateProtocolMapping(); +} + } // namespace test } // namespace net diff --git a/chromium/net/quic/quic_packet_creator.cc b/chromium/net/quic/quic_packet_creator.cc index 6d696768612..609ebcbd734 100644 --- a/chromium/net/quic/quic_packet_creator.cc +++ b/chromium/net/quic/quic_packet_creator.cc @@ -6,11 +6,13 @@ #include "base/logging.h" #include "net/quic/crypto/quic_random.h" +#include "net/quic/quic_ack_notifier.h" #include "net/quic/quic_fec_group.h" #include "net/quic/quic_utils.h" using base::StringPiece; using std::make_pair; +using std::max; using std::min; using std::pair; using std::vector; @@ -28,10 +30,8 @@ QuicPacketCreator::QuicPacketCreator(QuicGuid guid, fec_group_number_(0), is_server_(is_server), send_version_in_packet_(!is_server), - packet_size_(GetPacketHeaderSize(options_.send_guid_length, - send_version_in_packet_, - options_.send_sequence_number_length, - NOT_IN_FEC_GROUP)) { + sequence_number_length_(options_.send_sequence_number_length), + packet_size_(0) { framer_->set_fec_builder(this); } @@ -57,11 +57,6 @@ void QuicPacketCreator::MaybeStartFEC() { // Set the fec group number to the sequence number of the next packet. fec_group_number_ = sequence_number() + 1; fec_group_.reset(new QuicFecGroup()); - packet_size_ = GetPacketHeaderSize(options_.send_guid_length, - send_version_in_packet_, - options_.send_sequence_number_length, - IN_FEC_GROUP); - DCHECK_LE(packet_size_, options_.max_packet_length); } } @@ -77,6 +72,30 @@ void QuicPacketCreator::StopSendingVersion() { } } +void QuicPacketCreator::UpdateSequenceNumberLength( + QuicPacketSequenceNumber least_packet_awaited_by_peer, + QuicByteCount bytes_per_second) { + DCHECK_LE(least_packet_awaited_by_peer, sequence_number_ + 1); + // Since the packet creator will not change sequence number length mid FEC + // group, include the size of an FEC group to be safe. + const QuicPacketSequenceNumber current_delta = + options_.max_packets_per_fec_group + sequence_number_ + 1 + - least_packet_awaited_by_peer; + const uint64 congestion_window = + bytes_per_second / options_.max_packet_length; + const uint64 delta = max(current_delta, congestion_window); + + if (delta < 1 << ((PACKET_1BYTE_SEQUENCE_NUMBER * 8) - 2)) { + options_.send_sequence_number_length = PACKET_1BYTE_SEQUENCE_NUMBER; + } else if (delta < 1 << ((PACKET_2BYTE_SEQUENCE_NUMBER * 8) - 2)) { + options_.send_sequence_number_length = PACKET_2BYTE_SEQUENCE_NUMBER; + } else if (delta < 1 << ((PACKET_4BYTE_SEQUENCE_NUMBER * 8) - 2)) { + options_.send_sequence_number_length = PACKET_4BYTE_SEQUENCE_NUMBER; + } else { + options_.send_sequence_number_length = PACKET_6BYTE_SEQUENCE_NUMBER; + } +} + bool QuicPacketCreator::HasRoomForStreamFrame(QuicStreamId id, QuicStreamOffset offset) const { return BytesFree() > @@ -105,7 +124,12 @@ size_t QuicPacketCreator::CreateStreamFrame(QuicStreamId id, StreamFramePacketOverhead( framer_->version(), PACKET_8BYTE_GUID, kIncludeVersion, PACKET_6BYTE_SEQUENCE_NUMBER, IN_FEC_GROUP)); - DCHECK(HasRoomForStreamFrame(id, offset)); + if (!HasRoomForStreamFrame(id, offset)) { + LOG(DFATAL) << "No room for Stream frame, BytesFree: " << BytesFree() + << " MinStreamFrameSize: " + << QuicFramer::GetMinStreamFrameSize( + framer_->version(), id, offset, true); + } const size_t free_bytes = BytesFree(); size_t bytes_consumed = 0; @@ -145,12 +169,50 @@ size_t QuicPacketCreator::CreateStreamFrame(QuicStreamId id, return bytes_consumed; } +size_t QuicPacketCreator::CreateStreamFrameWithNotifier( + QuicStreamId id, + StringPiece data, + QuicStreamOffset offset, + bool fin, + QuicAckNotifier* notifier, + QuicFrame* frame) { + size_t bytes_consumed = CreateStreamFrame(id, data, offset, fin, frame); + + // The frame keeps track of the QuicAckNotifier until it is serialized into + // a packet. At that point the notifier is informed of the sequence number + // of the packet that this frame was eventually sent in. + frame->stream_frame->notifier = notifier; + + return bytes_consumed; +} + +SerializedPacket QuicPacketCreator::ReserializeAllFrames( + const QuicFrames& frames, + QuicSequenceNumberLength original_length) { + // Temporarily set the sequence number length and disable FEC. + const QuicSequenceNumberLength start_length = sequence_number_length_; + const QuicSequenceNumberLength start_options_length = + options_.send_sequence_number_length; + const QuicFecGroupNumber start_fec_group = fec_group_number_; + sequence_number_length_ = original_length; + options_.send_sequence_number_length = original_length; + fec_group_number_ = 0; + SerializedPacket serialized_packet = SerializeAllFrames(frames); + sequence_number_length_ = start_length; + options_.send_sequence_number_length = start_options_length; + fec_group_number_ = start_fec_group; + return serialized_packet; +} + SerializedPacket QuicPacketCreator::SerializeAllFrames( const QuicFrames& frames) { // TODO(satyamshekhar): Verify that this DCHECK won't fail. What about queued // frames from SendStreamData()[send_stream_should_flush_ == false && // data.empty() == true] and retransmit due to RTO. DCHECK_EQ(0u, queued_frames_.size()); + if (frames.empty()) { + LOG(DFATAL) << "Attempt to serialize empty packet"; + } for (size_t i = 0; i < frames.size(); ++i) { bool success = AddFrame(frames[i], false); DCHECK(success); @@ -167,10 +229,27 @@ bool QuicPacketCreator::HasPendingFrames() { size_t QuicPacketCreator::BytesFree() const { const size_t max_plaintext_size = framer_->GetMaxPlaintextSize(options_.max_packet_length); - if (packet_size_ > max_plaintext_size) { + if (PacketSize() >= max_plaintext_size) { return 0; } - return max_plaintext_size - packet_size_; + return max_plaintext_size - PacketSize(); +} + +size_t QuicPacketCreator::PacketSize() const { + if (queued_frames_.empty()) { + // Only adjust the sequence number length when the FEC group is not open, + // to ensure no packets in a group are too large. + if (fec_group_.get() == NULL || + fec_group_->NumReceivedPackets() == 0) { + sequence_number_length_ = options_.send_sequence_number_length; + } + packet_size_ = GetPacketHeaderSize(options_.send_guid_length, + send_version_in_packet_, + sequence_number_length_, + fec_group_number_ == 0 ? + NOT_IN_FEC_GROUP : IN_FEC_GROUP); + } + return packet_size_; } bool QuicPacketCreator::AddSavedFrame(const QuicFrame& frame) { @@ -178,18 +257,27 @@ bool QuicPacketCreator::AddSavedFrame(const QuicFrame& frame) { } SerializedPacket QuicPacketCreator::SerializePacket() { - DCHECK_EQ(false, queued_frames_.empty()); + if (queued_frames_.empty()) { + LOG(DFATAL) << "Attempt to serialize empty packet"; + } QuicPacketHeader header; FillPacketHeader(fec_group_number_, false, false, &header); - SerializedPacket serialized = framer_->BuildDataPacket( - header, queued_frames_, packet_size_); + SerializedPacket serialized = + framer_->BuildDataPacket(header, queued_frames_, packet_size_); + + // Run through all the included frames and if any of them have an AckNotifier + // registered, then inform the AckNotifier that it should be interested in + // this packet's sequence number. + for (QuicFrames::iterator it = queued_frames_.begin(); + it != queued_frames_.end(); ++it) { + if (it->type == STREAM_FRAME && it->stream_frame->notifier != NULL) { + it->stream_frame->notifier->AddSequenceNumber(serialized.sequence_number); + } + } + + packet_size_ = 0; queued_frames_.clear(); - packet_size_ = GetPacketHeaderSize(options_.send_guid_length, - send_version_in_packet_, - options_.send_sequence_number_length, - fec_group_.get() != NULL ? - IN_FEC_GROUP : NOT_IN_FEC_GROUP); serialized.retransmittable_frames = queued_retransmittable_frames_.release(); return serialized; } @@ -206,11 +294,7 @@ SerializedPacket QuicPacketCreator::SerializeFec() { SerializedPacket serialized = framer_->BuildFecPacket(header, fec_data); fec_group_.reset(NULL); fec_group_number_ = 0; - // Reset packet_size_, since the next packet may not have an FEC group. - packet_size_ = GetPacketHeaderSize(options_.send_guid_length, - send_version_in_packet_, - options_.send_sequence_number_length, - NOT_IN_FEC_GROUP); + packet_size_ = 0; DCHECK(serialized.packet); DCHECK_GE(options_.max_packet_length, serialized.packet->length()); return serialized; @@ -247,6 +331,7 @@ void QuicPacketCreator::FillPacketHeader(QuicFecGroupNumber fec_group, header->public_header.version_flag = send_version_in_packet_; header->fec_flag = fec_flag; header->packet_sequence_number = ++sequence_number_; + header->public_header.sequence_number_length = sequence_number_length_; bool entropy_flag; if (header->packet_sequence_number == 1) { @@ -278,6 +363,7 @@ bool QuicPacketCreator::AddFrame(const QuicFrame& frame, if (frame_len == 0) { return false; } + DCHECK_LT(0u, packet_size_); packet_size_ += frame_len; if (save_retransmittable_frames && ShouldRetransmit(frame)) { diff --git a/chromium/net/quic/quic_packet_creator.h b/chromium/net/quic/quic_packet_creator.h index a1d74fa532f..68b62166fee 100644 --- a/chromium/net/quic/quic_packet_creator.h +++ b/chromium/net/quic/quic_packet_creator.h @@ -23,6 +23,7 @@ namespace test { class QuicPacketCreatorPeer; } +class QuicAckNotifier; class QuicRandom; class NET_EXPORT_PRIVATE QuicPacketCreator : public QuicFecBuilderInterface { @@ -34,7 +35,7 @@ class NET_EXPORT_PRIVATE QuicPacketCreator : public QuicFecBuilderInterface { random_reorder(false), max_packets_per_fec_group(0), send_guid_length(PACKET_8BYTE_GUID), - send_sequence_number_length(PACKET_6BYTE_SEQUENCE_NUMBER) { + send_sequence_number_length(PACKET_1BYTE_SEQUENCE_NUMBER) { } size_t max_packet_length; @@ -69,6 +70,12 @@ class NET_EXPORT_PRIVATE QuicPacketCreator : public QuicFecBuilderInterface { // Makes the framer not serialize the protocol version in sent packets. void StopSendingVersion(); + // Update the sequence number length to use in future packets as soon as it + // can be safely changed. + void UpdateSequenceNumberLength( + QuicPacketSequenceNumber least_packet_awaited_by_peer, + QuicByteCount bytes_per_second); + // The overhead the framing will add for a packet with one frame. static size_t StreamFramePacketOverhead( QuicVersion version, @@ -89,18 +96,37 @@ class NET_EXPORT_PRIVATE QuicPacketCreator : public QuicFecBuilderInterface { bool fin, QuicFrame* frame); + // As above, but keeps track of an QuicAckNotifier that should be called when + // the packet that contains this stream frame is ACKed. + // The |notifier| is not owned by the QuicPacketGenerator and must outlive the + // generated packet. + size_t CreateStreamFrameWithNotifier(QuicStreamId id, + base::StringPiece data, + QuicStreamOffset offset, + bool fin, + QuicAckNotifier* notifier, + QuicFrame* frame); + // Serializes all frames into a single packet. All frames must fit into a // single packet. Also, sets the entropy hash of the serialized packet to a // random bool and returns that value as a member of SerializedPacket. // Never returns a RetransmittableFrames in SerializedPacket. SerializedPacket SerializeAllFrames(const QuicFrames& frames); + // Re-serializes frames with the original packet's sequence number length. + // Used for retransmitting packets to ensure they aren't too long. + SerializedPacket ReserializeAllFrames( + const QuicFrames& frames, QuicSequenceNumberLength original_length); + // Returns true if there are frames pending to be serialized. bool HasPendingFrames(); // Returns the number of bytes which are free to frames in the current packet. size_t BytesFree() const; + // Returns the number of bytes in the current packet, including the header. + size_t PacketSize() const; + // Adds |frame| to the packet creator's list of frames to be serialized. // Returns false if the frame doesn't fit into the current packet. bool AddSavedFrame(const QuicFrame& frame); @@ -132,6 +158,8 @@ class NET_EXPORT_PRIVATE QuicPacketCreator : public QuicFecBuilderInterface { QuicEncryptedPacket* SerializeVersionNegotiationPacket( const QuicVersionVector& supported_versions); + // Sequence number of the last created packet, or 0 if no packets have been + // created. QuicPacketSequenceNumber sequence_number() const { return sequence_number_; } @@ -170,7 +198,13 @@ class NET_EXPORT_PRIVATE QuicPacketCreator : public QuicFecBuilderInterface { // Controls whether protocol version should be included while serializing the // packet. bool send_version_in_packet_; - size_t packet_size_; + // The sequence number length for the current packet and the current FEC group + // if FEC is enabled. + // Mutable so PacketSize() can adjust it when the packet is empty. + mutable QuicSequenceNumberLength sequence_number_length_; + // packet_size_ is mutable because it's just a cache of the current size. + // packet_size should never be read directly, use PacketSize() instead. + mutable size_t packet_size_; QuicFrames queued_frames_; scoped_ptr<RetransmittableFrames> queued_retransmittable_frames_; diff --git a/chromium/net/quic/quic_packet_creator_test.cc b/chromium/net/quic/quic_packet_creator_test.cc index 1d793d23e18..51133b6880e 100644 --- a/chromium/net/quic/quic_packet_creator_test.cc +++ b/chromium/net/quic/quic_packet_creator_test.cc @@ -32,7 +32,6 @@ class QuicPacketCreatorTest : public ::testing::TestWithParam<bool> { QuicPacketCreatorTest() : server_framer_(QuicVersionMax(), QuicTime::Zero(), true), client_framer_(QuicVersionMax(), QuicTime::Zero(), false), - id_(1), sequence_number_(0), guid_(2), data_("foo"), @@ -64,7 +63,6 @@ class QuicPacketCreatorTest : public ::testing::TestWithParam<bool> { QuicFramer server_framer_; QuicFramer client_framer_; testing::StrictMock<MockFramerVisitor> framer_visitor_; - QuicStreamId id_; QuicPacketSequenceNumber sequence_number_; QuicGuid guid_; string data_; @@ -133,6 +131,120 @@ TEST_F(QuicPacketCreatorTest, SerializeWithFEC) { delete serialized.packet; } +TEST_F(QuicPacketCreatorTest, SerializeChangingSequenceNumberLength) { + frames_.push_back(QuicFrame(new QuicAckFrame(0u, QuicTime::Zero(), 0u))); + creator_.AddSavedFrame(frames_[0]); + creator_.options()->send_sequence_number_length = + PACKET_4BYTE_SEQUENCE_NUMBER; + SerializedPacket serialized = creator_.SerializePacket(); + // The sequence number length will not change mid-packet. + EXPECT_EQ(PACKET_1BYTE_SEQUENCE_NUMBER, serialized.sequence_number_length); + + { + InSequence s; + EXPECT_CALL(framer_visitor_, OnPacket()); + EXPECT_CALL(framer_visitor_, OnPacketHeader(_)); + EXPECT_CALL(framer_visitor_, OnAckFrame(_)); + EXPECT_CALL(framer_visitor_, OnPacketComplete()); + } + ProcessPacket(serialized.packet); + delete serialized.packet; + + creator_.AddSavedFrame(frames_[0]); + serialized = creator_.SerializePacket(); + // Now the actual sequence number length should have changed. + EXPECT_EQ(PACKET_4BYTE_SEQUENCE_NUMBER, serialized.sequence_number_length); + delete frames_[0].ack_frame; + + { + InSequence s; + EXPECT_CALL(framer_visitor_, OnPacket()); + EXPECT_CALL(framer_visitor_, OnPacketHeader(_)); + EXPECT_CALL(framer_visitor_, OnAckFrame(_)); + EXPECT_CALL(framer_visitor_, OnPacketComplete()); + } + ProcessPacket(serialized.packet); + delete serialized.packet; +} + +TEST_F(QuicPacketCreatorTest, SerializeWithFECChangingSequenceNumberLength) { + creator_.options()->max_packets_per_fec_group = 6; + ASSERT_FALSE(creator_.ShouldSendFec(false)); + creator_.MaybeStartFEC(); + + frames_.push_back(QuicFrame(new QuicAckFrame(0u, QuicTime::Zero(), 0u))); + creator_.AddSavedFrame(frames_[0]); + // Change the sequence number length mid-FEC group and it should not change. + creator_.options()->send_sequence_number_length = + PACKET_4BYTE_SEQUENCE_NUMBER; + SerializedPacket serialized = creator_.SerializePacket(); + EXPECT_EQ(PACKET_1BYTE_SEQUENCE_NUMBER, serialized.sequence_number_length); + + { + InSequence s; + EXPECT_CALL(framer_visitor_, OnPacket()); + EXPECT_CALL(framer_visitor_, OnPacketHeader(_)); + EXPECT_CALL(framer_visitor_, OnFecProtectedPayload(_)); + EXPECT_CALL(framer_visitor_, OnAckFrame(_)); + EXPECT_CALL(framer_visitor_, OnPacketComplete()); + } + ProcessPacket(serialized.packet); + delete serialized.packet; + + ASSERT_FALSE(creator_.ShouldSendFec(false)); + ASSERT_TRUE(creator_.ShouldSendFec(true)); + + serialized = creator_.SerializeFec(); + EXPECT_EQ(PACKET_1BYTE_SEQUENCE_NUMBER, serialized.sequence_number_length); + ASSERT_EQ(2u, serialized.sequence_number); + + { + InSequence s; + EXPECT_CALL(framer_visitor_, OnPacket()); + EXPECT_CALL(framer_visitor_, OnPacketHeader(_)); + EXPECT_CALL(framer_visitor_, OnFecData(_)); + EXPECT_CALL(framer_visitor_, OnPacketComplete()); + } + ProcessPacket(serialized.packet); + delete serialized.packet; + + // Ensure the next FEC group starts using the new sequence number length. + serialized = creator_.SerializeAllFrames(frames_); + EXPECT_EQ(PACKET_4BYTE_SEQUENCE_NUMBER, serialized.sequence_number_length); + delete frames_[0].stream_frame; + delete serialized.packet; +} + +TEST_F(QuicPacketCreatorTest, ReserializeFramesWithSequenceNumberLength) { + // If the original packet sequence number length, the current sequence number + // length, and the configured send sequence number length are different, the + // retransmit must sent with the original length and the others do not change. + creator_.options()->send_sequence_number_length = + PACKET_4BYTE_SEQUENCE_NUMBER; + QuicPacketCreatorPeer::SetSequenceNumberLength(&creator_, + PACKET_2BYTE_SEQUENCE_NUMBER); + frames_.push_back(QuicFrame(new QuicStreamFrame( + 0u, false, 0u, StringPiece("")))); + SerializedPacket serialized = + creator_.ReserializeAllFrames(frames_, PACKET_1BYTE_SEQUENCE_NUMBER); + EXPECT_EQ(PACKET_4BYTE_SEQUENCE_NUMBER, + creator_.options()->send_sequence_number_length); + EXPECT_EQ(PACKET_2BYTE_SEQUENCE_NUMBER, + QuicPacketCreatorPeer::GetSequenceNumberLength(&creator_)); + EXPECT_EQ(PACKET_1BYTE_SEQUENCE_NUMBER, serialized.sequence_number_length); + delete frames_[0].stream_frame; + + { + InSequence s; + EXPECT_CALL(framer_visitor_, OnPacket()); + EXPECT_CALL(framer_visitor_, OnPacketHeader(_)); + EXPECT_CALL(framer_visitor_, OnStreamFrame(_)); + EXPECT_CALL(framer_visitor_, OnPacketComplete()); + } + ProcessPacket(serialized.packet); + delete serialized.packet; +} + TEST_F(QuicPacketCreatorTest, SerializeConnectionClose) { QuicConnectionCloseFrame frame; frame.error_code = QUIC_NO_ERROR; @@ -221,6 +333,53 @@ TEST_F(QuicPacketCreatorTest, SerializeVersionNegotiationPacket) { client_framer_.ProcessPacket(*encrypted.get()); } +TEST_F(QuicPacketCreatorTest, UpdatePacketSequenceNumberLengthLeastAwaiting) { + EXPECT_EQ(PACKET_1BYTE_SEQUENCE_NUMBER, + creator_.options()->send_sequence_number_length); + + creator_.set_sequence_number(64); + creator_.UpdateSequenceNumberLength(2, 10000); + EXPECT_EQ(PACKET_1BYTE_SEQUENCE_NUMBER, + creator_.options()->send_sequence_number_length); + + creator_.set_sequence_number(64 * 256); + creator_.UpdateSequenceNumberLength(2, 10000); + EXPECT_EQ(PACKET_2BYTE_SEQUENCE_NUMBER, + creator_.options()->send_sequence_number_length); + + creator_.set_sequence_number(64 * 256 * 256); + creator_.UpdateSequenceNumberLength(2, 10000); + EXPECT_EQ(PACKET_4BYTE_SEQUENCE_NUMBER, + creator_.options()->send_sequence_number_length); + + creator_.set_sequence_number(GG_UINT64_C(64) * 256 * 256 * 256 * 256); + creator_.UpdateSequenceNumberLength(2, 10000); + EXPECT_EQ(PACKET_6BYTE_SEQUENCE_NUMBER, + creator_.options()->send_sequence_number_length); +} + +TEST_F(QuicPacketCreatorTest, UpdatePacketSequenceNumberLengthBandwidth) { + EXPECT_EQ(PACKET_1BYTE_SEQUENCE_NUMBER, + creator_.options()->send_sequence_number_length); + + creator_.UpdateSequenceNumberLength(1, 10000); + EXPECT_EQ(PACKET_1BYTE_SEQUENCE_NUMBER, + creator_.options()->send_sequence_number_length); + + creator_.UpdateSequenceNumberLength(1, 10000 * 256); + EXPECT_EQ(PACKET_2BYTE_SEQUENCE_NUMBER, + creator_.options()->send_sequence_number_length); + + creator_.UpdateSequenceNumberLength(1, 10000 * 256 * 256); + EXPECT_EQ(PACKET_4BYTE_SEQUENCE_NUMBER, + creator_.options()->send_sequence_number_length); + + creator_.UpdateSequenceNumberLength( + 1, GG_UINT64_C(1000) * 256 * 256 * 256 * 256); + EXPECT_EQ(PACKET_6BYTE_SEQUENCE_NUMBER, + creator_.options()->send_sequence_number_length); +} + INSTANTIATE_TEST_CASE_P(ToggleVersionSerialization, QuicPacketCreatorTest, ::testing::Values(false, true)); @@ -257,7 +416,7 @@ TEST_P(QuicPacketCreatorTest, CreateStreamFrameTooLarge) { creator_.options()->max_packet_length = GetPacketLengthForOneStream( client_framer_.version(), QuicPacketCreatorPeer::SendVersionInPacket(&creator_), - NOT_IN_FEC_GROUP, &payload_length); + PACKET_1BYTE_SEQUENCE_NUMBER, NOT_IN_FEC_GROUP, &payload_length); QuicFrame frame; const string too_long_payload(payload_length * 2, 'a'); size_t consumed = creator_.CreateStreamFrame( @@ -279,8 +438,7 @@ TEST_P(QuicPacketCreatorTest, AddFrameAndSerialize) { GetPacketHeaderSize( creator_.options()->send_guid_length, QuicPacketCreatorPeer::SendVersionInPacket(&creator_), - PACKET_6BYTE_SEQUENCE_NUMBER, - NOT_IN_FEC_GROUP), + PACKET_1BYTE_SEQUENCE_NUMBER, NOT_IN_FEC_GROUP), creator_.BytesFree()); // Add a variety of frame types and then a padding frame. @@ -323,7 +481,7 @@ TEST_P(QuicPacketCreatorTest, AddFrameAndSerialize) { GetPacketHeaderSize( creator_.options()->send_guid_length, QuicPacketCreatorPeer::SendVersionInPacket(&creator_), - PACKET_6BYTE_SEQUENCE_NUMBER, + PACKET_1BYTE_SEQUENCE_NUMBER, NOT_IN_FEC_GROUP), creator_.BytesFree()); } diff --git a/chromium/net/quic/quic_packet_generator.cc b/chromium/net/quic/quic_packet_generator.cc index 7600010a067..9151cb883b3 100644 --- a/chromium/net/quic/quic_packet_generator.cc +++ b/chromium/net/quic/quic_packet_generator.cc @@ -12,13 +12,15 @@ using base::StringPiece; namespace net { +class QuicAckNotifier; + QuicPacketGenerator::QuicPacketGenerator(DelegateInterface* delegate, DebugDelegateInterface* debug_delegate, QuicPacketCreator* creator) : delegate_(delegate), debug_delegate_(debug_delegate), packet_creator_(creator), - should_flush_(true), + batch_mode_(false), should_send_ack_(false), should_send_feedback_(false) { } @@ -60,7 +62,6 @@ void QuicPacketGenerator::SetShouldSendAck(bool also_send_feedback) { SendQueuedFrames(); } - void QuicPacketGenerator::AddControlFrame(const QuicFrame& frame) { queued_control_frames_.push_back(frame); SendQueuedFrames(); @@ -69,19 +70,39 @@ void QuicPacketGenerator::AddControlFrame(const QuicFrame& frame) { QuicConsumedData QuicPacketGenerator::ConsumeData(QuicStreamId id, StringPiece data, QuicStreamOffset offset, - bool fin) { + bool fin, + QuicAckNotifier* notifier) { + IsHandshake handshake = id == kCryptoStreamId ? IS_HANDSHAKE : NOT_HANDSHAKE; + // The caller should have flushed pending frames before sending handshake + // messages. + DCHECK(handshake == NOT_HANDSHAKE || !HasPendingFrames()); SendQueuedFrames(); size_t total_bytes_consumed = 0; bool fin_consumed = false; + if (!packet_creator_->HasRoomForStreamFrame(id, offset)) { + SerializeAndSendPacket(); + } while (delegate_->CanWrite(NOT_RETRANSMISSION, HAS_RETRANSMITTABLE_DATA, - NOT_HANDSHAKE)) { + handshake)) { QuicFrame frame; - size_t bytes_consumed = packet_creator_->CreateStreamFrame( + size_t bytes_consumed; + if (notifier != NULL) { + // We want to track which packet this stream frame ends up in. + bytes_consumed = packet_creator_->CreateStreamFrameWithNotifier( + id, data, offset + total_bytes_consumed, fin, notifier, &frame); + } else { + bytes_consumed = packet_creator_->CreateStreamFrame( id, data, offset + total_bytes_consumed, fin, &frame); - bool success = AddFrame(frame); - DCHECK(success); + } + if (!AddFrame(frame)) { + LOG(DFATAL) << "Failed to add stream frame."; + // Inability to add a STREAM frame creates an unrecoverable hole in a + // the stream, so it's best to close the connection. + delegate_->CloseConnection(QUIC_INTERNAL_ERROR, false); + return QuicConsumedData(0, false); + } total_bytes_consumed += bytes_consumed; fin_consumed = fin && bytes_consumed == data.size(); @@ -89,7 +110,7 @@ QuicConsumedData QuicPacketGenerator::ConsumeData(QuicStreamId id, DCHECK(data.empty() || packet_creator_->BytesFree() == 0u); // TODO(ianswett): Restore packet reordering. - if (should_flush_ || !packet_creator_->HasRoomForStreamFrame(id, offset)) { + if (!InBatchMode() || !packet_creator_->HasRoomForStreamFrame(id, offset)) { SerializeAndSendPacket(); } @@ -101,15 +122,15 @@ QuicConsumedData QuicPacketGenerator::ConsumeData(QuicStreamId id, } } - // Ensure the FEC group is closed at the end of this method unless other - // writes are pending. - if (should_flush_ && packet_creator_->ShouldSendFec(true)) { + // Ensure the FEC group is closed at the end of this method if not in batch + // mode. + if (!InBatchMode() && packet_creator_->ShouldSendFec(true)) { SerializedPacket serialized_fec = packet_creator_->SerializeFec(); DCHECK(serialized_fec.packet); delegate_->OnSerializedPacket(serialized_fec); } - DCHECK(!should_flush_ || !packet_creator_->HasPendingFrames()); + DCHECK(InBatchMode() || !packet_creator_->HasPendingFrames()); return QuicConsumedData(total_bytes_consumed, fin_consumed); } @@ -135,7 +156,7 @@ void QuicPacketGenerator::SendQueuedFrames() { } } - if (should_flush_) { + if (!InBatchMode()) { if (packet_creator_->HasPendingFrames()) { SerializeAndSendPacket(); } @@ -151,12 +172,16 @@ void QuicPacketGenerator::SendQueuedFrames() { } } +bool QuicPacketGenerator::InBatchMode() { + return batch_mode_; +} + void QuicPacketGenerator::StartBatchOperations() { - should_flush_ = false; + batch_mode_ = true; } void QuicPacketGenerator::FinishBatchOperations() { - should_flush_ = true; + batch_mode_ = false; SendQueuedFrames(); } @@ -171,7 +196,7 @@ bool QuicPacketGenerator::HasPendingFrames() const { bool QuicPacketGenerator::AddNextPendingFrame() { if (should_send_ack_) { - pending_ack_frame_.reset((delegate_->CreateAckFrame())); + pending_ack_frame_.reset(delegate_->CreateAckFrame()); // If we can't this add the frame now, then we still need to do so later. should_send_ack_ = !AddFrame(QuicFrame(pending_ack_frame_.get())); // Return success if we have cleared out this flag (i.e., added the frame). @@ -180,7 +205,7 @@ bool QuicPacketGenerator::AddNextPendingFrame() { } if (should_send_feedback_) { - pending_feedback_frame_.reset((delegate_->CreateFeedbackFrame())); + pending_feedback_frame_.reset(delegate_->CreateFeedbackFrame()); // If we can't this add the frame now, then we still need to do so later. should_send_feedback_ = !AddFrame(QuicFrame(pending_feedback_frame_.get())); // Return success if we have cleared out this flag (i.e., added the frame). diff --git a/chromium/net/quic/quic_packet_generator.h b/chromium/net/quic/quic_packet_generator.h index e8b09e64634..940b259adf7 100644 --- a/chromium/net/quic/quic_packet_generator.h +++ b/chromium/net/quic/quic_packet_generator.h @@ -15,7 +15,7 @@ // If the Delegate is not writable, then no operations will cause // a packet to be serialized. In particular: // * SetShouldSendAck will simply record that an ack is to be sent. -// * AddControlFram will enqueue the control frame. +// * AddControlFrame will enqueue the control frame. // * ConsumeData will do nothing. // // If the Delegate is writable, then the behavior depends on the second @@ -57,6 +57,8 @@ namespace net { +class QuicAckNotifier; + class NET_EXPORT_PRIVATE QuicPacketGenerator { public: class NET_EXPORT_PRIVATE DelegateInterface { @@ -69,6 +71,7 @@ class NET_EXPORT_PRIVATE QuicPacketGenerator { virtual QuicCongestionFeedbackFrame* CreateFeedbackFrame() = 0; // Takes ownership of |packet.packet| and |packet.retransmittable_frames|. virtual bool OnSerializedPacket(const SerializedPacket& packet) = 0; + virtual void CloseConnection(QuicErrorCode error, bool from_peer) = 0; }; // Interface which gets callbacks from the QuicPacketGenerator at interesting @@ -88,13 +91,28 @@ class NET_EXPORT_PRIVATE QuicPacketGenerator { virtual ~QuicPacketGenerator(); + // Indicates that an ACK frame should be sent. If |also_send_feedback| is + // true, then it also indicates a CONGESTION_FEEDBACK frame should be sent. + // The contents of the frame(s) will be generated via a call to the delegates + // CreateAckFrame() and CreateFeedbackFrame() when the packet is serialized. void SetShouldSendAck(bool also_send_feedback); void AddControlFrame(const QuicFrame& frame); + + // Given some data, may consume part or all of it and pass it to the + // packet creator to be serialized into packets. If not in batch + // mode, these packets will also be sent during this call. Also + // attaches a QuicAckNotifier to any created stream frames, which + // will be called once the frame is ACKed by the peer. The + // QuicAckNotifier is owned by the QuicConnection. |notifier| may + // be NULL. QuicConsumedData ConsumeData(QuicStreamId id, base::StringPiece data, QuicStreamOffset offset, - bool fin); + bool fin, + QuicAckNotifier* notifier); + // Indicates whether batch mode is currently enabled. + bool InBatchMode(); // Disables flushing. void StartBatchOperations(); // Enables flushing and flushes queued data. @@ -119,6 +137,7 @@ class NET_EXPORT_PRIVATE QuicPacketGenerator { bool AddNextPendingFrame(); bool AddFrame(const QuicFrame& frame); + void SerializeAndSendPacket(); DelegateInterface* delegate_; @@ -126,7 +145,10 @@ class NET_EXPORT_PRIVATE QuicPacketGenerator { QuicPacketCreator* packet_creator_; QuicFrames queued_control_frames_; - bool should_flush_; + + // True if batch mode is currently enabled. + bool batch_mode_; + // Flags to indicate the need for just-in-time construction of a frame. bool should_send_ack_; bool should_send_feedback_; diff --git a/chromium/net/quic/quic_packet_generator_test.cc b/chromium/net/quic/quic_packet_generator_test.cc index 45556f174e9..ec8dd56a677 100644 --- a/chromium/net/quic/quic_packet_generator_test.cc +++ b/chromium/net/quic/quic_packet_generator_test.cc @@ -21,6 +21,7 @@ using std::string; using testing::InSequence; using testing::Return; using testing::SaveArg; +using testing::StrictMock; using testing::_; namespace net { @@ -39,6 +40,7 @@ class MockDelegate : public QuicPacketGenerator::DelegateInterface { MOCK_METHOD0(CreateAckFrame, QuicAckFrame*()); MOCK_METHOD0(CreateFeedbackFrame, QuicCongestionFeedbackFrame*()); MOCK_METHOD1(OnSerializedPacket, bool(const SerializedPacket& packet)); + MOCK_METHOD2(CloseConnection, void(QuicErrorCode, bool)); void SetCanWriteAnything() { EXPECT_CALL(*this, CanWrite(NOT_RETRANSMISSION, _, _)) @@ -98,11 +100,11 @@ class QuicPacketGeneratorTest : public ::testing::Test { : framer_(QuicVersionMax(), QuicTime::Zero(), false), creator_(42, &framer_, &random_, false), generator_(&delegate_, NULL, &creator_), - packet_(0, NULL, 0, NULL), - packet2_(0, NULL, 0, NULL), - packet3_(0, NULL, 0, NULL), - packet4_(0, NULL, 0, NULL), - packet5_(0, NULL, 0, NULL) { + packet_(0, PACKET_1BYTE_SEQUENCE_NUMBER, NULL, 0, NULL), + packet2_(0, PACKET_1BYTE_SEQUENCE_NUMBER, NULL, 0, NULL), + packet3_(0, PACKET_1BYTE_SEQUENCE_NUMBER, NULL, 0, NULL), + packet4_(0, PACKET_1BYTE_SEQUENCE_NUMBER, NULL, 0, NULL), + packet5_(0, PACKET_1BYTE_SEQUENCE_NUMBER, NULL, 0, NULL) { } ~QuicPacketGeneratorTest() { @@ -198,7 +200,7 @@ class QuicPacketGeneratorTest : public ::testing::Test { QuicFramer framer_; MockRandom random_; QuicPacketCreator creator_; - testing::StrictMock<MockDelegate> delegate_; + StrictMock<MockDelegate> delegate_; QuicPacketGenerator generator_; SimpleQuicFramer simple_framer_; SerializedPacket packet_; @@ -211,6 +213,12 @@ class QuicPacketGeneratorTest : public ::testing::Test { scoped_ptr<char[]> data_array_; }; +class MockDebugDelegate : public QuicPacketGenerator::DebugDelegateInterface { + public: + MOCK_METHOD1(OnFrameAddedToPacket, + void(const QuicFrame&)); +}; + TEST_F(QuicPacketGeneratorTest, ShouldSendAck_NotWritable) { delegate_.SetCanNotWrite(); @@ -219,10 +227,14 @@ TEST_F(QuicPacketGeneratorTest, ShouldSendAck_NotWritable) { } TEST_F(QuicPacketGeneratorTest, ShouldSendAck_WritableAndShouldNotFlush) { + StrictMock<MockDebugDelegate> debug_delegate; + + generator_.set_debug_delegate(&debug_delegate); delegate_.SetCanWriteOnlyNonRetransmittable(); generator_.StartBatchOperations(); EXPECT_CALL(delegate_, CreateAckFrame()).WillOnce(Return(CreateAckFrame())); + EXPECT_CALL(debug_delegate, OnFrameAddedToPacket(_)).Times(1); generator_.SetShouldSendAck(false); EXPECT_TRUE(generator_.HasQueuedFrames()); @@ -315,7 +327,7 @@ TEST_F(QuicPacketGeneratorTest, AddControlFrame_WritableAndShouldFlush) { TEST_F(QuicPacketGeneratorTest, ConsumeData_NotWritable) { delegate_.SetCanNotWrite(); - QuicConsumedData consumed = generator_.ConsumeData(1, "foo", 2, true); + QuicConsumedData consumed = generator_.ConsumeData(1, "foo", 2, true, NULL); EXPECT_EQ(0u, consumed.bytes_consumed); EXPECT_FALSE(consumed.fin_consumed); EXPECT_FALSE(generator_.HasQueuedFrames()); @@ -325,7 +337,7 @@ TEST_F(QuicPacketGeneratorTest, ConsumeData_WritableAndShouldNotFlush) { delegate_.SetCanWriteAnything(); generator_.StartBatchOperations(); - QuicConsumedData consumed = generator_.ConsumeData(1, "foo", 2, true); + QuicConsumedData consumed = generator_.ConsumeData(1, "foo", 2, true, NULL); EXPECT_EQ(3u, consumed.bytes_consumed); EXPECT_TRUE(consumed.fin_consumed); EXPECT_TRUE(generator_.HasQueuedFrames()); @@ -336,7 +348,7 @@ TEST_F(QuicPacketGeneratorTest, ConsumeData_WritableAndShouldFlush) { EXPECT_CALL(delegate_, OnSerializedPacket(_)).WillOnce( DoAll(SaveArg<0>(&packet_), Return(true))); - QuicConsumedData consumed = generator_.ConsumeData(1, "foo", 2, true); + QuicConsumedData consumed = generator_.ConsumeData(1, "foo", 2, true, NULL); EXPECT_EQ(3u, consumed.bytes_consumed); EXPECT_TRUE(consumed.fin_consumed); EXPECT_FALSE(generator_.HasQueuedFrames()); @@ -351,8 +363,8 @@ TEST_F(QuicPacketGeneratorTest, delegate_.SetCanWriteAnything(); generator_.StartBatchOperations(); - generator_.ConsumeData(1, "foo", 2, true); - QuicConsumedData consumed = generator_.ConsumeData(3, "quux", 7, false); + generator_.ConsumeData(1, "foo", 2, true, NULL); + QuicConsumedData consumed = generator_.ConsumeData(3, "quux", 7, false, NULL); EXPECT_EQ(4u, consumed.bytes_consumed); EXPECT_FALSE(consumed.fin_consumed); EXPECT_TRUE(generator_.HasQueuedFrames()); @@ -362,8 +374,8 @@ TEST_F(QuicPacketGeneratorTest, ConsumeData_BatchOperations) { delegate_.SetCanWriteAnything(); generator_.StartBatchOperations(); - generator_.ConsumeData(1, "foo", 2, true); - QuicConsumedData consumed = generator_.ConsumeData(3, "quux", 7, false); + generator_.ConsumeData(1, "foo", 2, true, NULL); + QuicConsumedData consumed = generator_.ConsumeData(3, "quux", 7, false, NULL); EXPECT_EQ(4u, consumed.bytes_consumed); EXPECT_FALSE(consumed.fin_consumed); EXPECT_TRUE(generator_.HasQueuedFrames()); @@ -402,7 +414,7 @@ TEST_F(QuicPacketGeneratorTest, ConsumeDataFEC) { // Send enough data to create 3 packets: two full and one partial. size_t data_len = 2 * kMaxPacketSize + 100; QuicConsumedData consumed = - generator_.ConsumeData(3, CreateData(data_len), 0, true); + generator_.ConsumeData(3, CreateData(data_len), 0, true, NULL); EXPECT_EQ(data_len, consumed.bytes_consumed); EXPECT_TRUE(consumed.fin_consumed); EXPECT_FALSE(generator_.HasQueuedFrames()); @@ -434,7 +446,7 @@ TEST_F(QuicPacketGeneratorTest, ConsumeDataSendsFecAtEnd) { // Send enough data to create 2 packets: one full and one partial. size_t data_len = 1 * kMaxPacketSize + 100; QuicConsumedData consumed = - generator_.ConsumeData(3, CreateData(data_len), 0, true); + generator_.ConsumeData(3, CreateData(data_len), 0, true, NULL); EXPECT_EQ(data_len, consumed.bytes_consumed); EXPECT_TRUE(consumed.fin_consumed); EXPECT_FALSE(generator_.HasQueuedFrames()); @@ -444,6 +456,49 @@ TEST_F(QuicPacketGeneratorTest, ConsumeDataSendsFecAtEnd) { CheckPacketIsFec(packet3_, 1); } +TEST_F(QuicPacketGeneratorTest, ConsumeData_FramesPreviouslyQueued) { + // Set the packet size be enough for two stream frames with 0 stream offset, + // but not enough for a stream frame of 0 offset and one with non-zero offset. + creator_.options()->max_packet_length = + NullEncrypter().GetCiphertextSize(0) + + GetPacketHeaderSize(creator_.options()->send_guid_length, + true, + creator_.options()->send_sequence_number_length, + NOT_IN_FEC_GROUP) + + // Add an extra 3 bytes for the payload and 1 byte so BytesFree is larger + // than the GetMinStreamFrameSize. + QuicFramer::GetMinStreamFrameSize(framer_.version(), 1, 0, false) + 3 + + QuicFramer::GetMinStreamFrameSize(framer_.version(), 1, 0, true) + 1; + delegate_.SetCanWriteAnything(); + { + InSequence dummy; + EXPECT_CALL(delegate_, OnSerializedPacket(_)).WillOnce( + DoAll(SaveArg<0>(&packet_), Return(true))); + EXPECT_CALL(delegate_, OnSerializedPacket(_)).WillOnce( + DoAll(SaveArg<0>(&packet2_), Return(true))); + } + generator_.StartBatchOperations(); + // Queue enough data to prevent a stream frame with a non-zero offset from + // fitting. + QuicConsumedData consumed = generator_.ConsumeData(1, "foo", 0, false, NULL); + EXPECT_EQ(3u, consumed.bytes_consumed); + EXPECT_FALSE(consumed.fin_consumed); + EXPECT_TRUE(generator_.HasQueuedFrames()); + + // This frame will not fit with the existing frame, causing the queued frame + // to be serialized, and it will not fit with another frame like it, so it is + // serialized by itself. + consumed = generator_.ConsumeData(1, "bar", 3, true, NULL); + EXPECT_EQ(3u, consumed.bytes_consumed); + EXPECT_TRUE(consumed.fin_consumed); + EXPECT_FALSE(generator_.HasQueuedFrames()); + + PacketContents contents; + contents.num_stream_frames = 1; + CheckPacketContains(contents, packet_); + CheckPacketContains(contents, packet2_); +} + TEST_F(QuicPacketGeneratorTest, NotWritableThenBatchOperations) { delegate_.SetCanNotWrite(); @@ -462,7 +517,7 @@ TEST_F(QuicPacketGeneratorTest, NotWritableThenBatchOperations) { Return(CreateFeedbackFrame())); // Send some data and a control frame - generator_.ConsumeData(3, "quux", 7, false); + generator_.ConsumeData(3, "quux", 7, false, NULL); generator_.AddControlFrame(QuicFrame(CreateGoAwayFrame())); // All five frames will be flushed out in a single packet. @@ -509,7 +564,7 @@ TEST_F(QuicPacketGeneratorTest, NotWritableThenBatchOperations2) { // Send enough data to exceed one packet size_t data_len = kMaxPacketSize + 100; QuicConsumedData consumed = - generator_.ConsumeData(3, CreateData(data_len), 0, true); + generator_.ConsumeData(3, CreateData(data_len), 0, true, NULL); EXPECT_EQ(data_len, consumed.bytes_consumed); EXPECT_TRUE(consumed.fin_consumed); generator_.AddControlFrame(QuicFrame(CreateGoAwayFrame())); diff --git a/chromium/net/quic/quic_protocol.cc b/chromium/net/quic/quic_protocol.cc index 4c2e5527971..cdf3c6ce4d9 100644 --- a/chromium/net/quic/quic_protocol.cc +++ b/chromium/net/quic/quic_protocol.cc @@ -74,15 +74,6 @@ QuicPacketPublicHeader::QuicPacketPublicHeader( QuicPacketPublicHeader::~QuicPacketPublicHeader() {} -QuicPacketPublicHeader& QuicPacketPublicHeader::operator=( - const QuicPacketPublicHeader& other) { - guid = other.guid; - reset_flag = other.reset_flag; - version_flag = other.version_flag; - versions = other.versions; - return *this; -} - QuicPacketHeader::QuicPacketHeader() : fec_flag(false), entropy_flag(false), @@ -111,7 +102,8 @@ QuicStreamFrame::QuicStreamFrame(QuicStreamId stream_id, : stream_id(stream_id), fin(fin), offset(offset), - data(data) { + data(data), + notifier(NULL) { } uint32 MakeQuicTag(char a, char b, char c, char d) { @@ -129,10 +121,12 @@ QuicVersion QuicVersionMin() { QuicTag QuicVersionToQuicTag(const QuicVersion version) { switch (version) { - case QUIC_VERSION_7: - return MakeQuicTag('Q', '0', '0', '7'); case QUIC_VERSION_8: return MakeQuicTag('Q', '0', '0', '8'); + case QUIC_VERSION_9: + return MakeQuicTag('Q', '0', '0', '9'); + case QUIC_VERSION_10: + return MakeQuicTag('Q', '0', '1', '0'); default: // This shold be an ERROR because we should never attempt to convert an // invalid QuicVersion to be written to the wire. @@ -142,13 +136,16 @@ QuicTag QuicVersionToQuicTag(const QuicVersion version) { } QuicVersion QuicTagToQuicVersion(const QuicTag version_tag) { - const QuicTag quic_tag_v7 = MakeQuicTag('Q', '0', '0', '7'); const QuicTag quic_tag_v8 = MakeQuicTag('Q', '0', '0', '8'); + const QuicTag quic_tag_v9 = MakeQuicTag('Q', '0', '0', '9'); + const QuicTag quic_tag_v10 = MakeQuicTag('Q', '0', '1', '0'); - if (version_tag == quic_tag_v7) { - return QUIC_VERSION_7; - } else if (version_tag == quic_tag_v8) { + if (version_tag == quic_tag_v8) { return QUIC_VERSION_8; + } else if (version_tag == quic_tag_v9) { + return QUIC_VERSION_9; + } else if (version_tag == quic_tag_v10) { + return QUIC_VERSION_10; } else { // Reading from the client so this should not be considered an ERROR. DLOG(INFO) << "Unsupported QuicTag version: " @@ -163,8 +160,9 @@ return #x string QuicVersionToString(const QuicVersion version) { switch (version) { - RETURN_STRING_LITERAL(QUIC_VERSION_7); RETURN_STRING_LITERAL(QUIC_VERSION_8); + RETURN_STRING_LITERAL(QUIC_VERSION_9); + RETURN_STRING_LITERAL(QUIC_VERSION_10); default: return "QUIC_VERSION_UNSUPPORTED"; } @@ -184,6 +182,8 @@ string QuicVersionArrayToString(const QuicVersion versions[], ostream& operator<<(ostream& os, const QuicPacketHeader& header) { os << "{ guid: " << header.public_header.guid << ", guid_length:" << header.public_header.guid_length + << ", sequence_number_length:" + << header.public_header.sequence_number_length << ", reset_flag: " << header.public_header.reset_flag << ", version_flag: " << header.public_header.version_flag; if (header.public_header.version_flag) { @@ -415,6 +415,21 @@ void RetransmittableFrames::set_encryption_level(EncryptionLevel level) { encryption_level_ = level; } +SerializedPacket::SerializedPacket( + QuicPacketSequenceNumber sequence_number, + QuicSequenceNumberLength sequence_number_length, + QuicPacket* packet, + QuicPacketEntropyHash entropy_hash, + RetransmittableFrames* retransmittable_frames) + : sequence_number(sequence_number), + sequence_number_length(sequence_number_length), + packet(packet), + entropy_hash(entropy_hash), + retransmittable_frames(retransmittable_frames) { +} + +SerializedPacket::~SerializedPacket() {} + ostream& operator<<(ostream& os, const QuicEncryptedPacket& s) { os << s.length() << "-byte data"; return os; diff --git a/chromium/net/quic/quic_protocol.h b/chromium/net/quic/quic_protocol.h index 737bb16106f..26e6c027682 100644 --- a/chromium/net/quic/quic_protocol.h +++ b/chromium/net/quic/quic_protocol.h @@ -27,6 +27,7 @@ namespace net { using ::operator<<; +class QuicAckNotifier; class QuicPacket; struct QuicPacketHeader; @@ -41,6 +42,7 @@ typedef uint32 QuicHeaderId; // QuicTag is the type of a tag in the wire protocol. typedef uint32 QuicTag; typedef std::vector<QuicTag> QuicTagVector; +typedef uint32 QuicPriority; // TODO(rch): Consider Quic specific names for these constants. // Maximum size in bytes of a QUIC packet. @@ -98,12 +100,12 @@ enum IsHandshake { enum QuicFrameType { PADDING_FRAME = 0, - STREAM_FRAME, - ACK_FRAME, - CONGESTION_FEEDBACK_FRAME, RST_STREAM_FRAME, CONNECTION_CLOSE_FRAME, GOAWAY_FRAME, + STREAM_FRAME, + ACK_FRAME, + CONGESTION_FEEDBACK_FRAME, NUM_FRAME_TYPES }; @@ -187,8 +189,9 @@ enum QuicVersion { // Special case to indicate unknown/unsupported QUIC version. QUIC_VERSION_UNSUPPORTED = 0, - QUIC_VERSION_7 = 7, - QUIC_VERSION_8 = 8, // Current version. + QUIC_VERSION_8 = 8, + QUIC_VERSION_9 = 9, + QUIC_VERSION_10 = 10, // Current version. }; // This vector contains QUIC versions which we currently support. @@ -196,7 +199,7 @@ enum QuicVersion { // element, with subsequent elements in descending order (versions can be // skipped as necessary). static const QuicVersion kSupportedQuicVersions[] = - {QUIC_VERSION_8, QUIC_VERSION_7}; + {QUIC_VERSION_10, QUIC_VERSION_9}; typedef std::vector<QuicVersion> QuicVersionVector; @@ -216,6 +219,10 @@ NET_EXPORT_PRIVATE QuicTag QuicVersionToQuicTag(const QuicVersion version); // Returns QUIC_VERSION_UNSUPPORTED if version_tag cannot be understood. NET_EXPORT_PRIVATE QuicVersion QuicTagToQuicVersion(const QuicTag version_tag); +// Returns the appropriate QuicTag for a properly formed version string +// (e.g. Q008). +NET_EXPORT_PRIVATE QuicTag StringToQuicTag(std::string version); + // Helper function which translates from a QuicVersion to a string. // Returns strings corresponding to enum names (e.g. QUIC_VERSION_6). NET_EXPORT_PRIVATE std::string QuicVersionToString(const QuicVersion version); @@ -262,8 +269,8 @@ NET_EXPORT_PRIVATE size_t GetStartOfEncryptedData( enum QuicRstStreamErrorCode { QUIC_STREAM_NO_ERROR = 0, - // There was some server error which halted stream processing. - QUIC_SERVER_ERROR_PROCESSING_STREAM, + // There was some error which halted stream processing. + QUIC_ERROR_PROCESSING_STREAM, // We got two fin or reset offsets which did not match. QUIC_MULTIPLE_TERMINATION_OFFSETS, // We got bad payload and can not respond to it at the protocol level. @@ -273,115 +280,132 @@ enum QuicRstStreamErrorCode { QUIC_STREAM_CONNECTION_ERROR, // GoAway frame sent. No more stream can be created. QUIC_STREAM_PEER_GOING_AWAY, + // The stream has been cancelled. + QUIC_STREAM_CANCELLED, // No error. Used as bound while iterating. QUIC_STREAM_LAST_ERROR, }; +// These values must remain stable as they are uploaded to UMA histograms. +// To add a new error code, use the current value of QUIC_LAST_ERROR and +// increment QUIC_LAST_ERROR. enum QuicErrorCode { QUIC_NO_ERROR = 0, // Connection has reached an invalid state. - QUIC_INTERNAL_ERROR, + QUIC_INTERNAL_ERROR = 1, // There were data frames after the a fin or reset. - QUIC_STREAM_DATA_AFTER_TERMINATION, + QUIC_STREAM_DATA_AFTER_TERMINATION = 2, // Control frame is malformed. - QUIC_INVALID_PACKET_HEADER, + QUIC_INVALID_PACKET_HEADER = 3, // Frame data is malformed. - QUIC_INVALID_FRAME_DATA, + QUIC_INVALID_FRAME_DATA = 4, + // The packet contained no payload. + QUIC_MISSING_PAYLOAD = 48, // FEC data is malformed. - QUIC_INVALID_FEC_DATA, - // Stream rst data is malformed - QUIC_INVALID_RST_STREAM_DATA, - // Connection close data is malformed. - QUIC_INVALID_CONNECTION_CLOSE_DATA, - // GoAway data is malformed. - QUIC_INVALID_GOAWAY_DATA, - // Ack data is malformed. - QUIC_INVALID_ACK_DATA, + QUIC_INVALID_FEC_DATA = 5, + // STREAM frame data is malformed. + QUIC_INVALID_STREAM_DATA = 46, + // RST_STREAM frame data is malformed. + QUIC_INVALID_RST_STREAM_DATA = 6, + // CONNECTION_CLOSE frame data is malformed. + QUIC_INVALID_CONNECTION_CLOSE_DATA = 7, + // GOAWAY frame data is malformed. + QUIC_INVALID_GOAWAY_DATA = 8, + // ACK frame data is malformed. + QUIC_INVALID_ACK_DATA = 9, + // CONGESTION_FEEDBACK frame data is malformed. + QUIC_INVALID_CONGESTION_FEEDBACK_DATA = 47, // Version negotiation packet is malformed. - QUIC_INVALID_VERSION_NEGOTIATION_PACKET, + QUIC_INVALID_VERSION_NEGOTIATION_PACKET = 10, // Public RST packet is malformed. - QUIC_INVALID_PUBLIC_RST_PACKET, + QUIC_INVALID_PUBLIC_RST_PACKET = 11, // There was an error decrypting. - QUIC_DECRYPTION_FAILURE, + QUIC_DECRYPTION_FAILURE = 12, // There was an error encrypting. - QUIC_ENCRYPTION_FAILURE, + QUIC_ENCRYPTION_FAILURE = 13, // The packet exceeded kMaxPacketSize. - QUIC_PACKET_TOO_LARGE, + QUIC_PACKET_TOO_LARGE = 14, // Data was sent for a stream which did not exist. - QUIC_PACKET_FOR_NONEXISTENT_STREAM, + QUIC_PACKET_FOR_NONEXISTENT_STREAM = 15, // The peer is going away. May be a client or server. - QUIC_PEER_GOING_AWAY, + QUIC_PEER_GOING_AWAY = 16, // A stream ID was invalid. - QUIC_INVALID_STREAM_ID, + QUIC_INVALID_STREAM_ID = 17, + // A priority was invalid. + QUIC_INVALID_PRIORITY = 49, // Too many streams already open. - QUIC_TOO_MANY_OPEN_STREAMS, + QUIC_TOO_MANY_OPEN_STREAMS = 18, // Received public reset for this connection. - QUIC_PUBLIC_RESET, + QUIC_PUBLIC_RESET = 19, // Invalid protocol version. - QUIC_INVALID_VERSION, + QUIC_INVALID_VERSION = 20, // Stream reset before headers decompressed. - QUIC_STREAM_RST_BEFORE_HEADERS_DECOMPRESSED, + QUIC_STREAM_RST_BEFORE_HEADERS_DECOMPRESSED = 21, // The Header ID for a stream was too far from the previous. - QUIC_INVALID_HEADER_ID, + QUIC_INVALID_HEADER_ID = 22, // Negotiable parameter received during handshake had invalid value. - QUIC_INVALID_NEGOTIATED_VALUE, + QUIC_INVALID_NEGOTIATED_VALUE = 23, // There was an error decompressing data. - QUIC_DECOMPRESSION_FAILURE, + QUIC_DECOMPRESSION_FAILURE = 24, // We hit our prenegotiated (or default) timeout - QUIC_CONNECTION_TIMED_OUT, + QUIC_CONNECTION_TIMED_OUT = 25, // There was an error encountered migrating addresses - QUIC_ERROR_MIGRATING_ADDRESS, - // There was an error while writing the packet. - QUIC_PACKET_WRITE_ERROR, + QUIC_ERROR_MIGRATING_ADDRESS = 26, + // There was an error while writing to the socket. + QUIC_PACKET_WRITE_ERROR = 27, + // There was an error while reading from the socket. + QUIC_PACKET_READ_ERROR = 51, + // We received a STREAM_FRAME with no data and no fin flag set. + QUIC_INVALID_STREAM_FRAME = 50, // Crypto errors. // Hanshake failed. - QUIC_HANDSHAKE_FAILED, + QUIC_HANDSHAKE_FAILED = 28, // Handshake message contained out of order tags. - QUIC_CRYPTO_TAGS_OUT_OF_ORDER, + QUIC_CRYPTO_TAGS_OUT_OF_ORDER = 29, // Handshake message contained too many entries. - QUIC_CRYPTO_TOO_MANY_ENTRIES, + QUIC_CRYPTO_TOO_MANY_ENTRIES = 30, // Handshake message contained an invalid value length. - QUIC_CRYPTO_INVALID_VALUE_LENGTH, + QUIC_CRYPTO_INVALID_VALUE_LENGTH = 31, // A crypto message was received after the handshake was complete. - QUIC_CRYPTO_MESSAGE_AFTER_HANDSHAKE_COMPLETE, + QUIC_CRYPTO_MESSAGE_AFTER_HANDSHAKE_COMPLETE = 32, // A crypto message was received with an illegal message tag. - QUIC_INVALID_CRYPTO_MESSAGE_TYPE, + QUIC_INVALID_CRYPTO_MESSAGE_TYPE = 33, // A crypto message was received with an illegal parameter. - QUIC_INVALID_CRYPTO_MESSAGE_PARAMETER, + QUIC_INVALID_CRYPTO_MESSAGE_PARAMETER = 34, // A crypto message was received with a mandatory parameter missing. - QUIC_CRYPTO_MESSAGE_PARAMETER_NOT_FOUND, + QUIC_CRYPTO_MESSAGE_PARAMETER_NOT_FOUND = 35, // A crypto message was received with a parameter that has no overlap // with the local parameter. - QUIC_CRYPTO_MESSAGE_PARAMETER_NO_OVERLAP, + QUIC_CRYPTO_MESSAGE_PARAMETER_NO_OVERLAP = 36, // A crypto message was received that contained a parameter with too few // values. - QUIC_CRYPTO_MESSAGE_INDEX_NOT_FOUND, + QUIC_CRYPTO_MESSAGE_INDEX_NOT_FOUND = 37, // An internal error occured in crypto processing. - QUIC_CRYPTO_INTERNAL_ERROR, + QUIC_CRYPTO_INTERNAL_ERROR = 38, // A crypto handshake message specified an unsupported version. - QUIC_CRYPTO_VERSION_NOT_SUPPORTED, + QUIC_CRYPTO_VERSION_NOT_SUPPORTED = 39, // There was no intersection between the crypto primitives supported by the // peer and ourselves. - QUIC_CRYPTO_NO_SUPPORT, + QUIC_CRYPTO_NO_SUPPORT = 40, // The server rejected our client hello messages too many times. - QUIC_CRYPTO_TOO_MANY_REJECTS, + QUIC_CRYPTO_TOO_MANY_REJECTS = 41, // The client rejected the server's certificate chain or signature. - QUIC_PROOF_INVALID, + QUIC_PROOF_INVALID = 42, // A crypto message was received with a duplicate tag. - QUIC_CRYPTO_DUPLICATE_TAG, + QUIC_CRYPTO_DUPLICATE_TAG = 43, // A crypto message was received with the wrong encryption level (i.e. it // should have been encrypted but was not.) - QUIC_CRYPTO_ENCRYPTION_LEVEL_INCORRECT, + QUIC_CRYPTO_ENCRYPTION_LEVEL_INCORRECT = 44, // The server config for a server has expired. - QUIC_CRYPTO_SERVER_CONFIG_EXPIRED, + QUIC_CRYPTO_SERVER_CONFIG_EXPIRED = 45, // No error. Used as bound while iterating. - QUIC_LAST_ERROR, + QUIC_LAST_ERROR = 52, }; struct NET_EXPORT_PRIVATE QuicPacketPublicHeader { @@ -389,8 +413,6 @@ struct NET_EXPORT_PRIVATE QuicPacketPublicHeader { explicit QuicPacketPublicHeader(const QuicPacketPublicHeader& other); ~QuicPacketPublicHeader(); - QuicPacketPublicHeader& operator=(const QuicPacketPublicHeader& other); - // Universal header. All QuicPacket headers will have a guid and public flags. QuicGuid guid; QuicGuidLength guid_length; @@ -457,6 +479,10 @@ struct NET_EXPORT_PRIVATE QuicStreamFrame { bool fin; QuicStreamOffset offset; // Location of this data in the stream. base::StringPiece data; + + // If this is set, then when this packet is ACKed the AckNotifier will be + // informed. + QuicAckNotifier* notifier; }; // TODO(ianswett): Re-evaluate the trade-offs of hash_set vs set when framing @@ -675,10 +701,6 @@ struct NET_EXPORT_PRIVATE QuicFecData { base::StringPiece redundancy; }; -struct NET_EXPORT_PRIVATE QuicPacketData { - std::string data; -}; - class NET_EXPORT_PRIVATE QuicData { public: QuicData(const char* buffer, size_t length) @@ -739,8 +761,6 @@ class NET_EXPORT_PRIVATE QuicPacket : public QuicData { bool is_fec_packet() const { return is_fec_packet_; } - bool includes_version() const { return includes_version_; } - char* mutable_data() { return buffer_; } private: @@ -815,28 +835,27 @@ class NET_EXPORT_PRIVATE RetransmittableFrames { struct NET_EXPORT_PRIVATE SerializedPacket { SerializedPacket(QuicPacketSequenceNumber sequence_number, + QuicSequenceNumberLength sequence_number_length, QuicPacket* packet, QuicPacketEntropyHash entropy_hash, - RetransmittableFrames* retransmittable_frames) - : sequence_number(sequence_number), - packet(packet), - entropy_hash(entropy_hash), - retransmittable_frames(retransmittable_frames) {} + RetransmittableFrames* retransmittable_frames); + ~SerializedPacket(); QuicPacketSequenceNumber sequence_number; + QuicSequenceNumberLength sequence_number_length; QuicPacket* packet; QuicPacketEntropyHash entropy_hash; RetransmittableFrames* retransmittable_frames; + + // If set, these will be called when this packet is ACKed by the peer. + std::set<QuicAckNotifier*> notifiers; }; // A struct for functions which consume data payloads and fins. -// The first member of the pair indicates bytes consumed. -// The second member of the pair indicates if an incoming fin was consumed. struct QuicConsumedData { QuicConsumedData(size_t bytes_consumed, bool fin_consumed) : bytes_consumed(bytes_consumed), fin_consumed(fin_consumed) {} - // By default, gtest prints the raw bytes of an object. The bool data // member causes this object to have padding bytes, which causes the // default gtest object printer to read uninitialize memory. So we need @@ -844,7 +863,10 @@ struct QuicConsumedData { NET_EXPORT_PRIVATE friend std::ostream& operator<<( std::ostream& os, const QuicConsumedData& s); + // How many bytes were consumed. size_t bytes_consumed; + + // True if an incoming fin was consumed. bool fin_consumed; }; diff --git a/chromium/net/quic/quic_protocol_test.cc b/chromium/net/quic/quic_protocol_test.cc index b073d859f5d..52ed6645c94 100644 --- a/chromium/net/quic/quic_protocol_test.cc +++ b/chromium/net/quic/quic_protocol_test.cc @@ -56,8 +56,8 @@ TEST(QuicProtocolTest, QuicVersionToQuicTag) { #endif // Explicitly test a specific version. - EXPECT_EQ(MakeQuicTag('Q', '0', '0', '7'), - QuicVersionToQuicTag(QUIC_VERSION_7)); + EXPECT_EQ(MakeQuicTag('Q', '0', '1', '0'), + QuicVersionToQuicTag(QUIC_VERSION_10)); // Loop over all supported versions and make sure that we never hit the // default case (i.e. all supported versions should be successfully converted @@ -95,8 +95,8 @@ TEST(QuicProtocolTest, QuicTagToQuicVersion) { #endif // Explicitly test specific versions. - EXPECT_EQ(QUIC_VERSION_7, - QuicTagToQuicVersion(MakeQuicTag('Q', '0', '0', '7'))); + EXPECT_EQ(QUIC_VERSION_10, + QuicTagToQuicVersion(MakeQuicTag('Q', '0', '1', '0'))); for (size_t i = 0; i < arraysize(kSupportedQuicVersions); ++i) { const QuicVersion& version = kSupportedQuicVersions[i]; @@ -127,16 +127,17 @@ TEST(QuicProtocolTest, QuicTagToQuicVersionUnsupported) { } TEST(QuicProtocolTest, QuicVersionToString) { - EXPECT_EQ("QUIC_VERSION_7", - QuicVersionToString(QUIC_VERSION_7)); + EXPECT_EQ("QUIC_VERSION_8", + QuicVersionToString(QUIC_VERSION_8)); EXPECT_EQ("QUIC_VERSION_UNSUPPORTED", QuicVersionToString(QUIC_VERSION_UNSUPPORTED)); - QuicVersion single_version[] = {QUIC_VERSION_7}; - EXPECT_EQ("QUIC_VERSION_7,", QuicVersionArrayToString(single_version, - arraysize(single_version))); - QuicVersion multiple_versions[] = {QUIC_VERSION_8, QUIC_VERSION_7}; - EXPECT_EQ("QUIC_VERSION_8,QUIC_VERSION_7,", + QuicVersion single_version[] = {QUIC_VERSION_8}; + EXPECT_EQ("QUIC_VERSION_8,", QuicVersionArrayToString( + single_version, arraysize(single_version))); + QuicVersion multiple_versions[] = + {QUIC_VERSION_10, QUIC_VERSION_9, QUIC_VERSION_8}; + EXPECT_EQ("QUIC_VERSION_10,QUIC_VERSION_9,QUIC_VERSION_8,", QuicVersionArrayToString(multiple_versions, arraysize(multiple_versions))); } diff --git a/chromium/net/quic/quic_reliable_client_stream.cc b/chromium/net/quic/quic_reliable_client_stream.cc index 1951ae85a55..06b3178cdad 100644 --- a/chromium/net/quic/quic_reliable_client_stream.cc +++ b/chromium/net/quic/quic_reliable_client_stream.cc @@ -4,8 +4,10 @@ #include "net/quic/quic_reliable_client_stream.h" +#include "base/callback_helpers.h" #include "net/base/net_errors.h" #include "net/quic/quic_session.h" +#include "net/spdy/write_blocked_list.h" namespace net { @@ -45,6 +47,37 @@ void QuicReliableClientStream::TerminateFromPeer(bool half_close) { ReliableQuicStream::TerminateFromPeer(half_close); } +void QuicReliableClientStream::OnCanWrite() { + ReliableQuicStream::OnCanWrite(); + + if (!HasBufferedData() && !callback_.is_null()) { + base::ResetAndReturn(&callback_).Run(OK); + } +} + +QuicPriority QuicReliableClientStream::EffectivePriority() const { + if (delegate_ && delegate_->HasSendHeadersComplete()) { + return ReliableQuicStream::EffectivePriority(); + } + return kHighestPriority; +} + +int QuicReliableClientStream::WriteStreamData( + base::StringPiece data, + bool fin, + const CompletionCallback& callback) { + // We should not have data buffered. + DCHECK(!HasBufferedData()); + // Writes the data, or buffers it. + WriteData(data, fin); + if (!HasBufferedData()) { + return OK; + } + + callback_ = callback; + return ERR_IO_PENDING; +} + void QuicReliableClientStream::SetDelegate( QuicReliableClientStream::Delegate* delegate) { DCHECK((!delegate_ && delegate) || (delegate_ && !delegate)); diff --git a/chromium/net/quic/quic_reliable_client_stream.h b/chromium/net/quic/quic_reliable_client_stream.h index 77ac787f10d..bf3fc158d45 100644 --- a/chromium/net/quic/quic_reliable_client_stream.h +++ b/chromium/net/quic/quic_reliable_client_stream.h @@ -48,6 +48,9 @@ class NET_EXPORT_PRIVATE QuicReliableClientStream : public ReliableQuicStream { // Called when the stream is closed because of an error. virtual void OnError(int error) = 0; + // Returns true if sending of headers has completed. + virtual bool HasSendHeadersComplete() = 0; + protected: virtual ~Delegate() {} @@ -64,8 +67,16 @@ class NET_EXPORT_PRIVATE QuicReliableClientStream : public ReliableQuicStream { // ReliableQuicStream virtual uint32 ProcessData(const char* data, uint32 data_len) OVERRIDE; virtual void TerminateFromPeer(bool half_close) OVERRIDE; - using ReliableQuicStream::WriteData; + virtual void OnCanWrite() OVERRIDE; + virtual QuicPriority EffectivePriority() const OVERRIDE; + + // While the server's set_priority shouldn't be called externally, the creator + // of client-side streams should be able to set the priority. + using ReliableQuicStream::set_priority; + int WriteStreamData(base::StringPiece data, + bool fin, + const CompletionCallback& callback); // Set new |delegate|. |delegate| must not be NULL. // If this stream has already received data, OnDataReceived() will be // called on the delegate. @@ -79,6 +90,8 @@ class NET_EXPORT_PRIVATE QuicReliableClientStream : public ReliableQuicStream { BoundNetLog net_log_; Delegate* delegate_; + CompletionCallback callback_; + DISALLOW_COPY_AND_ASSIGN(QuicReliableClientStream); }; diff --git a/chromium/net/quic/quic_reliable_client_stream_test.cc b/chromium/net/quic/quic_reliable_client_stream_test.cc index 12402fd23eb..aaebda27fe0 100644 --- a/chromium/net/quic/quic_reliable_client_stream_test.cc +++ b/chromium/net/quic/quic_reliable_client_stream_test.cc @@ -5,6 +5,7 @@ #include "net/quic/quic_reliable_client_stream.h" #include "net/base/net_errors.h" +#include "net/base/test_completion_callback.h" #include "net/quic/quic_client_session.h" #include "net/quic/quic_utils.h" #include "net/quic/test_tools/quic_test_utils.h" @@ -12,6 +13,7 @@ using testing::Return; using testing::StrEq; +using testing::_; namespace net { namespace test { @@ -26,6 +28,7 @@ class MockDelegate : public QuicReliableClientStream::Delegate { MOCK_METHOD2(OnDataReceived, int(const char*, int)); MOCK_METHOD1(OnClose, void(QuicErrorCode)); MOCK_METHOD1(OnError, void(int)); + MOCK_METHOD0(HasSendHeadersComplete, bool()); private: DISALLOW_COPY_AND_ASSIGN(MockDelegate); @@ -77,6 +80,44 @@ TEST_F(QuicReliableClientStreamTest, OnError) { EXPECT_FALSE(stream_.GetDelegate()); } +TEST_F(QuicReliableClientStreamTest, WriteStreamData) { + EXPECT_CALL(delegate_, OnClose(QUIC_NO_ERROR)); + + const char kData1[] = "hello world"; + const size_t kDataLen = arraysize(kData1); + + // All data written. + EXPECT_CALL(session_, WritevData(stream_.id(), _, _, _, _)).WillOnce( + Return(QuicConsumedData(kDataLen, true))); + TestCompletionCallback callback; + EXPECT_EQ(OK, stream_.WriteStreamData(base::StringPiece(kData1, kDataLen), + true, callback.callback())); +} + +TEST_F(QuicReliableClientStreamTest, WriteStreamDataAsync) { + EXPECT_CALL(delegate_, HasSendHeadersComplete()); + EXPECT_CALL(delegate_, OnClose(QUIC_NO_ERROR)); + + const char kData1[] = "hello world"; + const size_t kDataLen = arraysize(kData1); + + // No data written. + EXPECT_CALL(session_, WritevData(stream_.id(), _, _, _, _)).WillOnce( + Return(QuicConsumedData(0, false))); + TestCompletionCallback callback; + EXPECT_EQ(ERR_IO_PENDING, + stream_.WriteStreamData(base::StringPiece(kData1, kDataLen), + true, callback.callback())); + ASSERT_FALSE(callback.have_result()); + + // All data written. + EXPECT_CALL(session_, WritevData(stream_.id(), _, _, _, _)).WillOnce( + Return(QuicConsumedData(kDataLen, true))); + stream_.OnCanWrite(); + ASSERT_TRUE(callback.have_result()); + EXPECT_EQ(OK, callback.WaitForResult()); +} + } // namespace } // namespace test } // namespace net diff --git a/chromium/net/quic/quic_sent_entropy_manager.h b/chromium/net/quic/quic_sent_entropy_manager.h index 4f684fc9a73..a101e738d06 100644 --- a/chromium/net/quic/quic_sent_entropy_manager.h +++ b/chromium/net/quic/quic_sent_entropy_manager.h @@ -40,10 +40,6 @@ class NET_EXPORT_PRIVATE QuicSentEntropyManager { // |sequence_number|. void ClearEntropyBefore(QuicPacketSequenceNumber sequence_number); - QuicPacketEntropyHash packets_entropy_hash() const { - return packets_entropy_hash_; - } - private: typedef linked_hash_map<QuicPacketSequenceNumber, std::pair<QuicPacketEntropyHash, diff --git a/chromium/net/quic/quic_sent_packet_manager.cc b/chromium/net/quic/quic_sent_packet_manager.cc new file mode 100644 index 00000000000..a8960f07aa0 --- /dev/null +++ b/chromium/net/quic/quic_sent_packet_manager.cc @@ -0,0 +1,221 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/quic/quic_sent_packet_manager.h" + +#include "base/logging.h" +#include "base/stl_util.h" + +using std::make_pair; + +namespace net { + +#define ENDPOINT (is_server_ ? "Server: " : " Client: ") + +QuicSentPacketManager::HelperInterface::~HelperInterface() { +} + +QuicSentPacketManager::QuicSentPacketManager(bool is_server, + HelperInterface* helper) + : is_server_(is_server), + helper_(helper) { +} + +QuicSentPacketManager::~QuicSentPacketManager() { + STLDeleteValues(&unacked_packets_); +} + +void QuicSentPacketManager::OnSerializedPacket( + const SerializedPacket& serialized_packet) { + if (serialized_packet.packet->is_fec_packet()) { + unacked_fec_packets_.insert(make_pair( + serialized_packet.sequence_number, + serialized_packet.retransmittable_frames)); + return; + } + + if (serialized_packet.retransmittable_frames == NULL) { + // Don't track ack/congestion feedback packets. + return; + } + + DCHECK(unacked_packets_.empty() || + unacked_packets_.rbegin()->first < + serialized_packet.sequence_number); + unacked_packets_[serialized_packet.sequence_number] = + serialized_packet.retransmittable_frames; + retransmission_map_[serialized_packet.sequence_number] = + RetransmissionInfo(serialized_packet.sequence_number, + serialized_packet.sequence_number_length); +} + +void QuicSentPacketManager::OnRetransmittedPacket( + QuicPacketSequenceNumber old_sequence_number, + QuicPacketSequenceNumber new_sequence_number) { + DCHECK(ContainsKey(unacked_packets_, old_sequence_number)); + DCHECK(ContainsKey(retransmission_map_, old_sequence_number)); + DCHECK(unacked_packets_.empty() || + unacked_packets_.rbegin()->first < new_sequence_number); + + RetransmissionInfo retransmission_info( + new_sequence_number, GetSequenceNumberLength(old_sequence_number)); + retransmission_info.number_retransmissions = + retransmission_map_[old_sequence_number].number_retransmissions + 1; + retransmission_map_.erase(old_sequence_number); + retransmission_map_[new_sequence_number] = retransmission_info; + + RetransmittableFrames* frames = unacked_packets_[old_sequence_number]; + DCHECK(frames); + unacked_packets_.erase(old_sequence_number); + unacked_packets_[new_sequence_number] = frames; +} + +void QuicSentPacketManager::HandleAckForSentPackets( + const QuicAckFrame& incoming_ack, + SequenceNumberSet* acked_packets) { + // Go through the packets we have not received an ack for and see if this + // incoming_ack shows they've been seen by the peer. + UnackedPacketMap::iterator it = unacked_packets_.begin(); + while (it != unacked_packets_.end()) { + QuicPacketSequenceNumber sequence_number = it->first; + if (sequence_number > helper_->GetPeerLargestObservedPacket()) { + // These are very new sequence_numbers. + break; + } + RetransmittableFrames* unacked = it->second; + if (!IsAwaitingPacket(incoming_ack.received_info, sequence_number)) { + // Packet was acked, so remove it from our unacked packet list. + DVLOG(1) << ENDPOINT <<"Got an ack for packet " << sequence_number; + acked_packets->insert(sequence_number); + delete unacked; + unacked_packets_.erase(it++); + retransmission_map_.erase(sequence_number); + } else { + // This is a packet which we planned on retransmitting and has not been + // seen at the time of this ack being sent out. See if it's our new + // lowest unacked packet. + DVLOG(1) << ENDPOINT << "still missing packet " << sequence_number; + ++it; + // The peer got packets after this sequence number. This is an explicit + // nack. + RetransmissionMap::iterator retransmission_it = + retransmission_map_.find(sequence_number); + if (retransmission_it == retransmission_map_.end()) { + continue; + } + size_t nack_count = ++(retransmission_it->second.number_nacks); + helper_->OnPacketNacked(sequence_number, nack_count); + } + } +} + +void QuicSentPacketManager::HandleAckForSentFecPackets( + const QuicAckFrame& incoming_ack, + SequenceNumberSet* acked_packets) { + UnackedPacketMap::iterator it = unacked_fec_packets_.begin(); + while (it != unacked_fec_packets_.end()) { + QuicPacketSequenceNumber sequence_number = it->first; + if (sequence_number > helper_->GetPeerLargestObservedPacket()) { + break; + } + if (!IsAwaitingPacket(incoming_ack.received_info, sequence_number)) { + DVLOG(1) << ENDPOINT << "Got an ack for fec packet: " << sequence_number; + acked_packets->insert(sequence_number); + unacked_fec_packets_.erase(it++); + } else { + DVLOG(1) << ENDPOINT << "Still missing ack for fec packet: " + << sequence_number; + ++it; + } + } +} + +void QuicSentPacketManager::DiscardPacket( + QuicPacketSequenceNumber sequence_number) { + UnackedPacketMap::iterator unacked_it = + unacked_packets_.find(sequence_number); + if (unacked_it == unacked_packets_.end()) { + // Packet was not meant to be retransmitted. + DCHECK(!ContainsKey(retransmission_map_, sequence_number)); + return; + } + + // Delete the unacked packet. + delete unacked_it->second; + unacked_packets_.erase(unacked_it); + retransmission_map_.erase(sequence_number); +} + +bool QuicSentPacketManager::IsRetransmission( + QuicPacketSequenceNumber sequence_number) const { + RetransmissionMap::const_iterator it = + retransmission_map_.find(sequence_number); + return it != retransmission_map_.end() && + it->second.number_retransmissions > 0; +} + +size_t QuicSentPacketManager::GetRetransmissionCount( + QuicPacketSequenceNumber sequence_number) const { + DCHECK(ContainsKey(retransmission_map_, sequence_number)); + RetransmissionMap::const_iterator it = + retransmission_map_.find(sequence_number); + return it->second.number_retransmissions; +} + +bool QuicSentPacketManager::IsUnacked( + QuicPacketSequenceNumber sequence_number) const { + return ContainsKey(unacked_packets_, sequence_number); +} + +bool QuicSentPacketManager::IsFecUnacked( + QuicPacketSequenceNumber sequence_number) const { + return ContainsKey(unacked_fec_packets_, sequence_number); +} + +const RetransmittableFrames& QuicSentPacketManager::GetRetransmittableFrames( + QuicPacketSequenceNumber sequence_number) const { + DCHECK(ContainsKey(unacked_packets_, sequence_number)); + DCHECK(ContainsKey(retransmission_map_, sequence_number)); + + return *unacked_packets_.find(sequence_number)->second; +} + +QuicSequenceNumberLength QuicSentPacketManager::GetSequenceNumberLength( + QuicPacketSequenceNumber sequence_number) const { + DCHECK(ContainsKey(unacked_packets_, sequence_number)); + DCHECK(ContainsKey(retransmission_map_, sequence_number)); + + return retransmission_map_.find( + sequence_number)->second.sequence_number_length; +} + +bool QuicSentPacketManager::HasUnackedPackets() const { + return !unacked_packets_.empty(); +} + +size_t QuicSentPacketManager::GetNumUnackedPackets() const { + return unacked_packets_.size(); +} + +QuicPacketSequenceNumber +QuicSentPacketManager::GetLeastUnackedSentPacket() const { + if (unacked_packets_.empty()) { + // If there are no unacked packets, set the least unacked packet to + // the sequence number of the next packet sent. + return helper_->GetNextPacketSequenceNumber(); + } + + return unacked_packets_.begin()->first; +} + +SequenceNumberSet QuicSentPacketManager::GetUnackedPackets() const { + SequenceNumberSet unacked_packets; + for (UnackedPacketMap::const_iterator it = unacked_packets_.begin(); + it != unacked_packets_.end(); ++it) { + unacked_packets.insert(it->first); + } + return unacked_packets; +} + +} // namespace net diff --git a/chromium/net/quic/quic_sent_packet_manager.h b/chromium/net/quic/quic_sent_packet_manager.h new file mode 100644 index 00000000000..355ea498b65 --- /dev/null +++ b/chromium/net/quic/quic_sent_packet_manager.h @@ -0,0 +1,140 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_QUIC_QUIC_SENT_PACKET_MANAGER_H_ +#define NET_QUIC_QUIC_SENT_PACKET_MANAGER_H_ + +#include <deque> +#include <list> +#include <map> +#include <queue> +#include <set> +#include <utility> +#include <vector> + +#include "base/containers/hash_tables.h" +#include "net/base/linked_hash_map.h" +#include "net/quic/quic_protocol.h" + +namespace net { + +class NET_EXPORT_PRIVATE QuicSentPacketManager { + public: + class NET_EXPORT_PRIVATE HelperInterface { + public: + virtual QuicPacketSequenceNumber GetPeerLargestObservedPacket() = 0; + virtual QuicPacketSequenceNumber GetNextPacketSequenceNumber() = 0; + + // Called when a packet has been explicitly NACKd + virtual void OnPacketNacked(QuicPacketSequenceNumber sequence_number, + size_t nack_count) = 0; + virtual ~HelperInterface(); + }; + + QuicSentPacketManager(bool is_server, HelperInterface* helper); + virtual ~QuicSentPacketManager(); + + // Called when a new packet is serialized. If the packet contains + // retransmittable data, it will be added to the unacked packet map. + void OnSerializedPacket(const SerializedPacket& serialized_packet); + + // Called when a packet is retransmitted with a new sequence number. + // Replaces the old entry in the unacked packet map with the new + // sequence number. + void OnRetransmittedPacket(QuicPacketSequenceNumber old_sequence_number, + QuicPacketSequenceNumber new_sequence_number); + + // Process the incoming ack looking for newly ack'd data packets. + void HandleAckForSentPackets(const QuicAckFrame& incoming_ack, + SequenceNumberSet* acked_packets); + + // Process the incoming ack looking for newly ack'd FEC packets. + void HandleAckForSentFecPackets(const QuicAckFrame& incoming_ack, + SequenceNumberSet* acked_packets); + + // Discards all information about packet |sequence_number|. + void DiscardPacket(QuicPacketSequenceNumber sequence_number); + + // Returns true if |sequence_number| is a retransmission of a packet. + bool IsRetransmission(QuicPacketSequenceNumber sequence_number) const; + + // Returns the number of times the data in the packet |sequence_number| + // has been transmitted. + size_t GetRetransmissionCount( + QuicPacketSequenceNumber sequence_number) const; + + // Returns true if the non-FEC packet |sequence_number| is unacked. + bool IsUnacked(QuicPacketSequenceNumber sequence_number) const; + + // Returns true if the FEC packet |sequence_number| is unacked. + bool IsFecUnacked(QuicPacketSequenceNumber sequence_number) const; + + // Returns the RetransmittableFrames for |sequence_number|. + const RetransmittableFrames& GetRetransmittableFrames( + QuicPacketSequenceNumber sequence_number) const; + + // Returns the length of the serialized sequence number for + // the packet |sequence_number|. + QuicSequenceNumberLength GetSequenceNumberLength( + QuicPacketSequenceNumber sequence_number) const; + + // Returns true if there are any unacked packets. + bool HasUnackedPackets() const; + + // Returns the number of unacked packets. + size_t GetNumUnackedPackets() const; + + // Returns the smallest sequence number of a sent packet which has not + // been acked by the peer. If all packets have been acked, returns the + // sequence number of the next packet that will be sent. + QuicPacketSequenceNumber GetLeastUnackedSentPacket() const; + + // Returns the set of unacked packet sequence numbers. + SequenceNumberSet GetUnackedPackets() const; + + private: + struct RetransmissionInfo { + RetransmissionInfo() {} + explicit RetransmissionInfo(QuicPacketSequenceNumber sequence_number, + QuicSequenceNumberLength sequence_number_length) + : sequence_number(sequence_number), + sequence_number_length(sequence_number_length), + number_nacks(0), + number_retransmissions(0) { + } + + QuicPacketSequenceNumber sequence_number; + QuicSequenceNumberLength sequence_number_length; + size_t number_nacks; + size_t number_retransmissions; + }; + + typedef linked_hash_map<QuicPacketSequenceNumber, + RetransmittableFrames*> UnackedPacketMap; + typedef base::hash_map<QuicPacketSequenceNumber, + RetransmissionInfo> RetransmissionMap; + + // When new packets are created which may be retransmitted, they are added + // to this map, which contains owning pointers to the contained frames. + UnackedPacketMap unacked_packets_; + + // Pending fec packets that have not been acked yet. These packets need to be + // cleared out of the cgst_window after a timeout since FEC packets are never + // retransmitted. + // TODO(satyamshekhar): What should be the timeout for these packets? + UnackedPacketMap unacked_fec_packets_; + + // Map from sequence number to the retransmission info. + RetransmissionMap retransmission_map_; + + // Tracks if the connection was created by the server. + bool is_server_; + HelperInterface* helper_; + + DISALLOW_COPY_AND_ASSIGN(QuicSentPacketManager); +}; + +} // namespace net + +#endif // NET_QUIC_QUIC_SENT_PACKET_MANAGER_H_ diff --git a/chromium/net/quic/quic_sent_packet_manager_test.cc b/chromium/net/quic/quic_sent_packet_manager_test.cc new file mode 100644 index 00000000000..8f5cb3d19d4 --- /dev/null +++ b/chromium/net/quic/quic_sent_packet_manager_test.cc @@ -0,0 +1,80 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/quic/quic_sent_packet_manager.h" + +#include "net/quic/test_tools/quic_test_utils.h" +#include "testing/gmock/include/gmock/gmock.h" +#include "testing/gtest/include/gtest/gtest.h" + +using testing::_; +using testing::Return; +using testing::StrictMock; + +namespace net { +namespace test { +namespace { + +class MockHelper : public QuicSentPacketManager::HelperInterface { + public: + MOCK_METHOD0(GetPeerLargestObservedPacket, QuicPacketSequenceNumber()); + MOCK_METHOD0(GetNextPacketSequenceNumber, QuicPacketSequenceNumber()); + MOCK_METHOD2(OnPacketNacked, void(QuicPacketSequenceNumber sequence_number, + size_t nack_count)); +}; + +class QuicSentPacketManagerTest : public ::testing::Test { + protected: + QuicSentPacketManagerTest() + : manager_(true, &helper_) { + } + + testing::StrictMock<MockHelper> helper_; + QuicSentPacketManager manager_; +}; + +TEST_F(QuicSentPacketManagerTest, GetLeastUnackedSentPacket) { + EXPECT_CALL(helper_, GetNextPacketSequenceNumber()).WillOnce(Return(1u)); + EXPECT_EQ(1u, manager_.GetLeastUnackedSentPacket()); +} + +TEST_F(QuicSentPacketManagerTest, GetLeastUnackedSentPacketUnacked) { + scoped_ptr<QuicPacket> packet(QuicPacket::NewDataPacket( + NULL, 0, false, PACKET_8BYTE_GUID, false, PACKET_6BYTE_SEQUENCE_NUMBER)); + SerializedPacket serialized_packet(1u, PACKET_6BYTE_SEQUENCE_NUMBER, + packet.get(), 7u, + new RetransmittableFrames()); + + manager_.OnSerializedPacket(serialized_packet); + EXPECT_EQ(1u, manager_.GetLeastUnackedSentPacket()); +} + +TEST_F(QuicSentPacketManagerTest, GetLeastUnackedSentPacketUnackedFec) { + scoped_ptr<QuicPacket> packet(QuicPacket::NewFecPacket( + NULL, 0, false, PACKET_8BYTE_GUID, false, PACKET_6BYTE_SEQUENCE_NUMBER)); + SerializedPacket serialized_packet(1u, PACKET_6BYTE_SEQUENCE_NUMBER, + packet.get(), 7u, NULL); + + manager_.OnSerializedPacket(serialized_packet); + // FEC packets do not count as "unacked". + EXPECT_CALL(helper_, GetNextPacketSequenceNumber()).WillOnce(Return(2u)); + EXPECT_EQ(2u, manager_.GetLeastUnackedSentPacket()); +} + +TEST_F(QuicSentPacketManagerTest, GetLeastUnackedSentPacketDiscardUnacked) { + scoped_ptr<QuicPacket> packet(QuicPacket::NewDataPacket( + NULL, 0, false, PACKET_8BYTE_GUID, false, PACKET_6BYTE_SEQUENCE_NUMBER)); + SerializedPacket serialized_packet(1u, PACKET_6BYTE_SEQUENCE_NUMBER, + packet.get(), 7u, + new RetransmittableFrames()); + + manager_.OnSerializedPacket(serialized_packet); + manager_.DiscardPacket(1u); + EXPECT_CALL(helper_, GetNextPacketSequenceNumber()).WillOnce(Return(2u)); + EXPECT_EQ(2u, manager_.GetLeastUnackedSentPacket()); +} + +} // namespace +} // namespace test +} // namespace net diff --git a/chromium/net/quic/quic_session.cc b/chromium/net/quic/quic_session.cc index a07f413cbd2..9389af403ca 100644 --- a/chromium/net/quic/quic_session.cc +++ b/chromium/net/quic/quic_session.cc @@ -18,6 +18,7 @@ using std::vector; namespace net { const size_t kMaxPrematurelyClosedStreamsTracked = 20; +const size_t kMaxZombieStreams = 20; #define ENDPOINT (is_server_ ? "Server: " : " Client: ") @@ -32,12 +33,8 @@ class VisitorShim : public QuicConnectionVisitorInterface { public: explicit VisitorShim(QuicSession* session) : session_(session) {} - virtual bool OnPacket(const IPEndPoint& self_address, - const IPEndPoint& peer_address, - const QuicPacketHeader& header, - const vector<QuicStreamFrame>& frame) OVERRIDE { - bool accepted = session_->OnPacket(self_address, peer_address, header, - frame); + virtual bool OnStreamFrames(const vector<QuicStreamFrame>& frames) OVERRIDE { + bool accepted = session_->OnStreamFrames(frames); session_->PostProcessAfterData(); return accepted; } @@ -51,22 +48,26 @@ class VisitorShim : public QuicConnectionVisitorInterface { session_->PostProcessAfterData(); } - virtual void OnAck(const SequenceNumberSet& acked_packets) OVERRIDE { - session_->OnAck(acked_packets); - session_->PostProcessAfterData(); - } - virtual bool OnCanWrite() OVERRIDE { bool rc = session_->OnCanWrite(); session_->PostProcessAfterData(); return rc; } + virtual void OnSuccessfulVersionNegotiation( + const QuicVersion& version) OVERRIDE { + session_->OnSuccessfulVersionNegotiation(version); + } + virtual void ConnectionClose(QuicErrorCode error, bool from_peer) OVERRIDE { session_->ConnectionClose(error, from_peer); // The session will go away, so don't bother with cleanup. } + virtual bool HasPendingHandshake() const OVERRIDE { + return session_->HasPendingHandshake(); + } + private: QuicSession* session_; }; @@ -83,7 +84,8 @@ QuicSession::QuicSession(QuicConnection* connection, largest_peer_created_stream_id_(0), error_(QUIC_NO_ERROR), goaway_received_(false), - goaway_sent_(false) { + goaway_sent_(false), + has_pending_handshake_(false) { connection_->set_visitor(visitor_shim_.get()); connection_->SetIdleNetworkTimeout(config_.idle_connection_state_lifetime()); @@ -99,16 +101,7 @@ QuicSession::~QuicSession() { STLDeleteValues(&stream_map_); } -bool QuicSession::OnPacket(const IPEndPoint& self_address, - const IPEndPoint& peer_address, - const QuicPacketHeader& header, - const vector<QuicStreamFrame>& frames) { - if (header.public_header.guid != connection()->guid()) { - DLOG(INFO) << ENDPOINT << "Got packet header for invalid GUID: " - << header.public_header.guid; - return false; - } - +bool QuicSession::OnStreamFrames(const vector<QuicStreamFrame>& frames) { for (size_t i = 0; i < frames.size(); ++i) { // TODO(rch) deal with the error case of stream id 0 if (IsClosedStream(frames[i].stream_id)) { @@ -133,9 +126,19 @@ bool QuicSession::OnPacket(const IPEndPoint& self_address, } for (size_t i = 0; i < frames.size(); ++i) { - ReliableQuicStream* stream = GetStream(frames[i].stream_id); - if (stream) { - stream->OnStreamFrame(frames[i]); + QuicStreamId stream_id = frames[i].stream_id; + ReliableQuicStream* stream = GetStream(stream_id); + if (!stream) { + continue; + } + stream->OnStreamFrame(frames[i]); + + // If the stream had been prematurely closed, and the + // headers are now decompressed, then we are finally finished + // with this stream. + if (ContainsKey(zombie_streams_, stream_id) && + stream->headers_decompressed()) { + CloseZombieStream(stream_id); } } @@ -162,6 +165,16 @@ void QuicSession::OnRstStream(const QuicRstStreamFrame& frame) { if (!stream) { return; // Errors are handled by GetStream. } + if (ContainsKey(zombie_streams_, stream->id())) { + // If this was a zombie stream then we close it out now. + CloseZombieStream(stream->id()); + // However, since the headers still have not been decompressed, we want to + // mark it a prematurely closed so that if we ever receive frames + // for this stream we can close the connection. + DCHECK(!stream->headers_decompressed()); + AddPrematurelyClosedStream(frame.stream_id); + return; + } stream->OnStreamReset(frame.error_code); } @@ -171,6 +184,7 @@ void QuicSession::OnGoAway(const QuicGoAwayFrame& frame) { } void QuicSession::ConnectionClose(QuicErrorCode error, bool from_peer) { + DCHECK(!connection_->connected()); if (error_ == QUIC_NO_ERROR) { error_ = error; } @@ -190,13 +204,22 @@ void QuicSession::ConnectionClose(QuicErrorCode error, bool from_peer) { bool QuicSession::OnCanWrite() { // We latch this here rather than doing a traditional loop, because streams // may be modifying the list as we loop. - int remaining_writes = write_blocked_streams_.NumObjects(); + int remaining_writes = write_blocked_streams_.NumBlockedStreams(); while (!connection_->HasQueuedData() && remaining_writes > 0) { - DCHECK(!write_blocked_streams_.IsEmpty()); - ReliableQuicStream* stream = - GetStream(write_blocked_streams_.GetNextBlockedObject()); + DCHECK(write_blocked_streams_.HasWriteBlockedStreams()); + int index = write_blocked_streams_.GetHighestPriorityWriteBlockedList(); + if (index == -1) { + LOG(DFATAL) << "WriteBlockedStream is missing"; + connection_->CloseConnection(QUIC_INTERNAL_ERROR, false); + return true; // We have no write blocked streams. + } + QuicStreamId stream_id = write_blocked_streams_.PopFront(index); + if (stream_id == kCryptoStreamId) { + has_pending_handshake_ = false; // We just popped it. + } + ReliableQuicStream* stream = GetStream(stream_id); if (stream != NULL) { // If the stream can't write all bytes, it'll re-add itself to the blocked // list. @@ -205,20 +228,25 @@ bool QuicSession::OnCanWrite() { --remaining_writes; } - return write_blocked_streams_.IsEmpty(); + return !write_blocked_streams_.HasWriteBlockedStreams(); +} + +bool QuicSession::HasPendingHandshake() const { + return has_pending_handshake_; } -QuicConsumedData QuicSession::WriteData(QuicStreamId id, - StringPiece data, - QuicStreamOffset offset, - bool fin) { - return connection_->SendStreamData(id, data, offset, fin); +QuicConsumedData QuicSession::WritevData(QuicStreamId id, + const struct iovec* iov, + int iov_count, + QuicStreamOffset offset, + bool fin) { + return connection_->SendvStreamData(id, iov, iov_count, offset, fin); } void QuicSession::SendRstStream(QuicStreamId id, QuicRstStreamErrorCode error) { connection_->SendRstStream(id, error); - CloseStream(id); + CloseStreamInner(id, true); } void QuicSession::SendGoAway(QuicErrorCode error_code, const string& reason) { @@ -227,6 +255,11 @@ void QuicSession::SendGoAway(QuicErrorCode error_code, const string& reason) { } void QuicSession::CloseStream(QuicStreamId stream_id) { + CloseStreamInner(stream_id, false); +} + +void QuicSession::CloseStreamInner(QuicStreamId stream_id, + bool locally_reset) { DLOG(INFO) << ENDPOINT << "Closing stream " << stream_id; ReliableStreamMap::iterator it = stream_map_.find(stream_id); @@ -235,18 +268,65 @@ void QuicSession::CloseStream(QuicStreamId stream_id) { return; } ReliableQuicStream* stream = it->second; - if (!stream->headers_decompressed()) { - if (prematurely_closed_streams_.size() == - kMaxPrematurelyClosedStreamsTracked) { - prematurely_closed_streams_.erase(prematurely_closed_streams_.begin()); + if (connection_->connected() && !stream->headers_decompressed()) { + // If the stream is being closed locally (for example a client cancelling + // a request before receiving the response) then we need to make sure that + // we keep the stream alive long enough to process any response or + // RST_STREAM frames. + if (locally_reset && !is_server_) { + AddZombieStream(stream_id); + return; } - prematurely_closed_streams_.insert(make_pair(stream->id(), true)); + + // This stream has been closed before the headers were decompressed. + // This might cause problems with head of line blocking of headers. + // If the peer sent headers which were lost but we now close the stream + // we will never be able to decompress headers for other streams. + // To deal with this, we keep track of streams which have been closed + // prematurely. If we ever receive data frames for this steam, then we + // know there actually has been a problem and we close the connection. + AddPrematurelyClosedStream(stream->id()); } closed_streams_.push_back(it->second); + if (ContainsKey(zombie_streams_, stream->id())) { + zombie_streams_.erase(stream->id()); + } stream_map_.erase(it); stream->OnClose(); } +void QuicSession::AddZombieStream(QuicStreamId stream_id) { + if (zombie_streams_.size() == kMaxZombieStreams) { + QuicStreamId oldest_zombie_stream_id = zombie_streams_.begin()->first; + CloseZombieStream(oldest_zombie_stream_id); + // However, since the headers still have not been decompressed, we want to + // mark it a prematurely closed so that if we ever receive frames + // for this stream we can close the connection. + AddPrematurelyClosedStream(oldest_zombie_stream_id); + } + zombie_streams_.insert(make_pair(stream_id, true)); +} + +void QuicSession::CloseZombieStream(QuicStreamId stream_id) { + DCHECK(ContainsKey(zombie_streams_, stream_id)); + zombie_streams_.erase(stream_id); + ReliableQuicStream* stream = GetStream(stream_id); + if (!stream) { + return; + } + stream_map_.erase(stream_id); + stream->OnClose(); + closed_streams_.push_back(stream); +} + +void QuicSession::AddPrematurelyClosedStream(QuicStreamId stream_id) { + if (prematurely_closed_streams_.size() == + kMaxPrematurelyClosedStreamsTracked) { + prematurely_closed_streams_.erase(prematurely_closed_streams_.begin()); + } + prematurely_closed_streams_.insert(make_pair(stream_id, true)); +} + bool QuicSession::IsEncryptionEstablished() { return GetCryptoStream()->encryption_established(); } @@ -283,6 +363,14 @@ void QuicSession::OnCryptoHandshakeEvent(CryptoHandshakeEvent event) { } } +void QuicSession::OnCryptoHandshakeMessageSent( + const CryptoHandshakeMessage& message) { +} + +void QuicSession::OnCryptoHandshakeMessageReceived( + const CryptoHandshakeMessage& message) { +} + QuicConfig* QuicSession::config() { return &config_; } @@ -290,7 +378,7 @@ QuicConfig* QuicSession::config() { void QuicSession::ActivateStream(ReliableQuicStream* stream) { DLOG(INFO) << ENDPOINT << "num_streams: " << stream_map_.size() << ". activating " << stream->id(); - DCHECK(stream_map_.count(stream->id()) == 0); + DCHECK_EQ(stream_map_.count(stream->id()), 0u); stream_map_[stream->id()] = stream; } @@ -343,12 +431,13 @@ ReliableQuicStream* QuicSession::GetIncomingReliableStream( connection()->SendConnectionClose(QUIC_INVALID_STREAM_ID); return NULL; } - if (largest_peer_created_stream_id_ != 0) { - for (QuicStreamId id = largest_peer_created_stream_id_ + 2; - id < stream_id; - id += 2) { - implicitly_created_streams_.insert(id); - } + if (largest_peer_created_stream_id_ == 0) { + largest_peer_created_stream_id_= 1; + } + for (QuicStreamId id = largest_peer_created_stream_id_ + 2; + id < stream_id; + id += 2) { + implicitly_created_streams_.insert(id); } largest_peer_created_stream_id_ = stream_id; } @@ -365,7 +454,10 @@ bool QuicSession::IsClosedStream(QuicStreamId id) { if (id == kCryptoStreamId) { return false; } - if (stream_map_.count(id) != 0) { + if (ContainsKey(zombie_streams_, id)) { + return true; + } + if (ContainsKey(stream_map_, id)) { // Stream is active return false; } @@ -381,11 +473,20 @@ bool QuicSession::IsClosedStream(QuicStreamId id) { } size_t QuicSession::GetNumOpenStreams() const { - return stream_map_.size() + implicitly_created_streams_.size(); + return stream_map_.size() + implicitly_created_streams_.size() - + zombie_streams_.size(); } -void QuicSession::MarkWriteBlocked(QuicStreamId id) { - write_blocked_streams_.AddBlockedObject(id); +void QuicSession::MarkWriteBlocked(QuicStreamId id, QuicPriority priority) { + if (id == kCryptoStreamId) { + DCHECK(!has_pending_handshake_); + has_pending_handshake_ = true; + // TODO(jar): Be sure to use the highest priority for the crypto stream, + // perhaps by adding a "special" priority for it that is higher than + // kHighestPriority. + priority = kHighestPriority; + } + write_blocked_streams_.PushBack(id, priority); } void QuicSession::MarkDecompressionBlocked(QuicHeaderId header_id, diff --git a/chromium/net/quic/quic_session.h b/chromium/net/quic/quic_session.h index 9f7aceb7581..b58feb27426 100644 --- a/chromium/net/quic/quic_session.h +++ b/chromium/net/quic/quic_session.h @@ -13,7 +13,6 @@ #include "base/containers/hash_tables.h" #include "net/base/ip_endpoint.h" #include "net/base/linked_hash_map.h" -#include "net/quic/blocked_list.h" #include "net/quic/quic_connection.h" #include "net/quic/quic_crypto_stream.h" #include "net/quic/quic_packet_creator.h" @@ -21,6 +20,7 @@ #include "net/quic/quic_spdy_compressor.h" #include "net/quic/quic_spdy_decompressor.h" #include "net/quic/reliable_quic_stream.h" +#include "net/spdy/write_blocked_list.h" namespace net { @@ -59,26 +59,28 @@ class NET_EXPORT_PRIVATE QuicSession : public QuicConnectionVisitorInterface { virtual ~QuicSession(); // QuicConnectionVisitorInterface methods: - virtual bool OnPacket(const IPEndPoint& self_address, - const IPEndPoint& peer_address, - const QuicPacketHeader& header, - const std::vector<QuicStreamFrame>& frame) OVERRIDE; + virtual bool OnStreamFrames( + const std::vector<QuicStreamFrame>& frames) OVERRIDE; virtual void OnRstStream(const QuicRstStreamFrame& frame) OVERRIDE; virtual void OnGoAway(const QuicGoAwayFrame& frame) OVERRIDE; virtual void ConnectionClose(QuicErrorCode error, bool from_peer) OVERRIDE; + virtual void OnSuccessfulVersionNegotiation( + const QuicVersion& version) OVERRIDE{} // Not needed for HTTP. - virtual void OnAck(const SequenceNumberSet& acked_packets) OVERRIDE {} virtual bool OnCanWrite() OVERRIDE; + virtual bool HasPendingHandshake() const OVERRIDE; // Called by streams when they want to write data to the peer. // Returns a pair with the number of bytes consumed from data, and a boolean // indicating if the fin bit was consumed. This does not indicate the data // has been sent on the wire: it may have been turned into a packet and queued // if the socket was unexpectedly blocked. - virtual QuicConsumedData WriteData(QuicStreamId id, - base::StringPiece data, - QuicStreamOffset offset, - bool fin); + virtual QuicConsumedData WritevData(QuicStreamId id, + const struct iovec* iov, + int iov_count, + QuicStreamOffset offset, + bool fin); + // Called by streams when they want to close the stream in both directions. virtual void SendRstStream(QuicStreamId id, QuicRstStreamErrorCode error); @@ -106,6 +108,14 @@ class NET_EXPORT_PRIVATE QuicSession : public QuicConnectionVisitorInterface { // Servers will simply call it once with HANDSHAKE_CONFIRMED. virtual void OnCryptoHandshakeEvent(CryptoHandshakeEvent event); + // Called by the QuicCryptoStream when a handshake message is sent. + virtual void OnCryptoHandshakeMessageSent( + const CryptoHandshakeMessage& message); + + // Called by the QuicCryptoStream when a handshake message is received. + virtual void OnCryptoHandshakeMessageReceived( + const CryptoHandshakeMessage& message); + // Returns mutable config for this session. Returned config is owned // by QuicSession. QuicConfig* config(); @@ -129,7 +139,7 @@ class NET_EXPORT_PRIVATE QuicSession : public QuicConnectionVisitorInterface { // been implicitly created. virtual size_t GetNumOpenStreams() const; - void MarkWriteBlocked(QuicStreamId id); + void MarkWriteBlocked(QuicStreamId id, QuicPriority priority); // Marks that |stream_id| is blocked waiting to decompress the // headers identified by |decompression_id|. @@ -203,6 +213,24 @@ class NET_EXPORT_PRIVATE QuicSession : public QuicConnectionVisitorInterface { typedef base::hash_map<QuicStreamId, ReliableQuicStream*> ReliableStreamMap; + // Performs the work required to close |stream_id|. If |locally_reset| + // then the stream has been reset by this endpoint, not by the peer. This + // means the stream may become a zombie stream which needs to stay + // around until headers have been decompressed. + void CloseStreamInner(QuicStreamId stream_id, bool locally_reset); + + // Adds |stream_id| to the zobmie stream map, closing the oldest + // zombie stream if the set is full. + void AddZombieStream(QuicStreamId stream_id); + + // Closes the zombie stream |stream_id| and removes it from the zombie + // stream map. + void CloseZombieStream(QuicStreamId stream_id); + + // Adds |stream_id| to the prematurely closed stream map, removing the + // oldest prematurely closed stream if the set is full. + void AddPrematurelyClosedStream(QuicStreamId stream_id); + scoped_ptr<QuicConnection> connection_; // Tracks the last 20 streams which closed without decompressing headers. @@ -210,6 +238,12 @@ class NET_EXPORT_PRIVATE QuicSession : public QuicConnectionVisitorInterface { // Ideally this would be a linked_hash_set as the boolean is unused. linked_hash_map<QuicStreamId, bool> prematurely_closed_streams_; + // Streams which have been locally reset before decompressing headers + // from the peer. These streams need to stay open long enough to + // process any headers from the peer. + // Ideally this would be a linked_hash_set as the boolean is unused. + linked_hash_map<QuicStreamId, bool> zombie_streams_; + // A shim to stand between the connection and the session, to handle stream // deletions. scoped_ptr<VisitorShim> visitor_shim_; @@ -234,7 +268,7 @@ class NET_EXPORT_PRIVATE QuicSession : public QuicConnectionVisitorInterface { base::hash_set<QuicStreamId> implicitly_created_streams_; // A list of streams which need to write more data. - BlockedList<QuicStreamId> write_blocked_streams_; + WriteBlockedList<QuicStreamId> write_blocked_streams_; // A map of headers waiting to be compressed, and the streams // they are associated with. @@ -250,6 +284,9 @@ class NET_EXPORT_PRIVATE QuicSession : public QuicConnectionVisitorInterface { // Whether a GoAway has been sent. bool goaway_sent_; + // Indicate if there is pending data for the crypto stream. + bool has_pending_handshake_; + DISALLOW_COPY_AND_ASSIGN(QuicSession); }; diff --git a/chromium/net/quic/quic_session_test.cc b/chromium/net/quic/quic_session_test.cc index e417c436840..e90b20f81c6 100644 --- a/chromium/net/quic/quic_session_test.cc +++ b/chromium/net/quic/quic_session_test.cc @@ -13,6 +13,7 @@ #include "net/quic/quic_protocol.h" #include "net/quic/test_tools/quic_connection_peer.h" #include "net/quic/test_tools/quic_test_utils.h" +#include "net/quic/test_tools/reliable_quic_stream_peer.h" #include "net/spdy/spdy_framer.h" #include "testing/gmock/include/gmock/gmock.h" #include "testing/gtest/include/gtest/gtest.h" @@ -22,11 +23,15 @@ using std::set; using std::vector; using testing::_; using testing::InSequence; +using testing::InvokeWithoutArgs; +using testing::StrictMock; namespace net { namespace test { namespace { +const QuicPriority kSomeMiddlePriority = 2; + class TestCryptoStream : public QuicCryptoStream { public: explicit TestCryptoStream(QuicSession* session) @@ -45,6 +50,8 @@ class TestCryptoStream : public QuicCryptoStream { EXPECT_EQ(QUIC_NO_ERROR, error); session()->OnCryptoHandshakeEvent(QuicSession::HANDSHAKE_CONFIRMED); } + + MOCK_METHOD0(OnCanWrite, void()); }; class TestStream : public ReliableQuicStream { @@ -62,6 +69,23 @@ class TestStream : public ReliableQuicStream { MOCK_METHOD0(OnCanWrite, void()); }; +// Poor man's functor for use as callback in a mock. +class StreamBlocker { + public: + StreamBlocker(QuicSession* session, QuicStreamId stream_id) + : session_(session), + stream_id_(stream_id) { + } + + void MarkWriteBlocked() { + session_->MarkWriteBlocked(stream_id_, kSomeMiddlePriority); + } + + private: + QuicSession* const session_; + const QuicStreamId stream_id_; +}; + class TestSession : public QuicSession { public: TestSession(QuicConnection* connection, bool is_server) @@ -69,7 +93,7 @@ class TestSession : public QuicSession { crypto_stream_(this) { } - virtual QuicCryptoStream* GetCryptoStream() OVERRIDE { + virtual TestCryptoStream* GetCryptoStream() OVERRIDE { return &crypto_stream_; } @@ -91,11 +115,6 @@ class TestSession : public QuicSession { return QuicSession::GetIncomingReliableStream(stream_id); } - // Helper method for gmock - void MarkTwoWriteBlocked() { - this->MarkWriteBlocked(2); - } - TestCryptoStream crypto_stream_; }; @@ -105,6 +124,33 @@ class QuicSessionTest : public ::testing::Test { : guid_(1), connection_(new MockConnection(guid_, IPEndPoint(), false)), session_(connection_, true) { + headers_[":host"] = "www.google.com"; + headers_[":path"] = "/index.hml"; + headers_[":scheme"] = "http"; + headers_["cookie"] = + "__utma=208381060.1228362404.1372200928.1372200928.1372200928.1; " + "__utmc=160408618; " + "GX=DQAAAOEAAACWJYdewdE9rIrW6qw3PtVi2-d729qaa-74KqOsM1NVQblK4VhX" + "hoALMsy6HOdDad2Sz0flUByv7etmo3mLMidGrBoljqO9hSVA40SLqpG_iuKKSHX" + "RW3Np4bq0F0SDGDNsW0DSmTS9ufMRrlpARJDS7qAI6M3bghqJp4eABKZiRqebHT" + "pMU-RXvTI5D5oCF1vYxYofH_l1Kviuiy3oQ1kS1enqWgbhJ2t61_SNdv-1XJIS0" + "O3YeHLmVCs62O6zp89QwakfAWK9d3IDQvVSJzCQsvxvNIvaZFa567MawWlXg0Rh" + "1zFMi5vzcns38-8_Sns; " + "GA=v*2%2Fmem*57968640*47239936%2Fmem*57968640*47114716%2Fno-nm-" + "yj*15%2Fno-cc-yj*5%2Fpc-ch*133685%2Fpc-s-cr*133947%2Fpc-s-t*1339" + "47%2Fno-nm-yj*4%2Fno-cc-yj*1%2Fceft-as*1%2Fceft-nqas*0%2Fad-ra-c" + "v_p%2Fad-nr-cv_p-f*1%2Fad-v-cv_p*859%2Fad-ns-cv_p-f*1%2Ffn-v-ad%" + "2Fpc-t*250%2Fpc-cm*461%2Fpc-s-cr*722%2Fpc-s-t*722%2Fau_p*4" + "SICAID=AJKiYcHdKgxum7KMXG0ei2t1-W4OD1uW-ecNsCqC0wDuAXiDGIcT_HA2o1" + "3Rs1UKCuBAF9g8rWNOFbxt8PSNSHFuIhOo2t6bJAVpCsMU5Laa6lewuTMYI8MzdQP" + "ARHKyW-koxuhMZHUnGBJAM1gJODe0cATO_KGoX4pbbFxxJ5IicRxOrWK_5rU3cdy6" + "edlR9FsEdH6iujMcHkbE5l18ehJDwTWmBKBzVD87naobhMMrF6VvnDGxQVGp9Ir_b" + "Rgj3RWUoPumQVCxtSOBdX0GlJOEcDTNCzQIm9BSfetog_eP_TfYubKudt5eMsXmN6" + "QnyXHeGeK2UINUzJ-D30AFcpqYgH9_1BvYSpi7fc7_ydBU8TaD8ZRxvtnzXqj0RfG" + "tuHghmv3aD-uzSYJ75XDdzKdizZ86IG6Fbn1XFhYZM-fbHhm3mVEXnyRW4ZuNOLFk" + "Fas6LMcVC6Q8QLlHYbXBpdNFuGbuZGUnav5C-2I_-46lL0NGg3GewxGKGHvHEfoyn" + "EFFlEYHsBQ98rXImL8ySDycdLEFvBPdtctPmWCfTxwmoSMLHU2SCVDhbqMWU5b0yr" + "JBCScs_ejbKaqBDoB7ZGxTvqlrB__2ZmnHHjCr8RgMRtKNtIeuZAo "; } void CheckClosedStreams() { @@ -125,11 +171,14 @@ class QuicSessionTest : public ::testing::Test { QuicGuid guid_; MockConnection* connection_; TestSession session_; - QuicConnectionVisitorInterface* visitor_; - hash_map<QuicStreamId, ReliableQuicStream*>* streams_; set<QuicStreamId> closed_streams_; + SpdyHeaderBlock headers_; }; +TEST_F(QuicSessionTest, PeerAddress) { + EXPECT_EQ(IPEndPoint(), session_.peer_address()); +} + TEST_F(QuicSessionTest, IsCryptoHandshakeConfirmed) { EXPECT_FALSE(session_.IsCryptoHandshakeConfirmed()); CryptoHandshakeMessage message; @@ -144,11 +193,22 @@ TEST_F(QuicSessionTest, IsClosedStreamDefault) { } } +TEST_F(QuicSessionTest, ImplicitlyCreatedStreams) { + ASSERT_TRUE(session_.GetIncomingReliableStream(7) != NULL); + // Both 3 and 5 should be implicitly created. + EXPECT_FALSE(session_.IsClosedStream(3)); + EXPECT_FALSE(session_.IsClosedStream(5)); + ASSERT_TRUE(session_.GetIncomingReliableStream(5) != NULL); + ASSERT_TRUE(session_.GetIncomingReliableStream(3) != NULL); +} + TEST_F(QuicSessionTest, IsClosedStreamLocallyCreated) { TestStream* stream2 = session_.CreateOutgoingReliableStream(); EXPECT_EQ(2u, stream2->id()); + ReliableQuicStreamPeer::SetHeadersDecompressed(stream2, true); TestStream* stream4 = session_.CreateOutgoingReliableStream(); EXPECT_EQ(4u, stream4->id()); + ReliableQuicStreamPeer::SetHeadersDecompressed(stream4, true); CheckClosedStreams(); CloseStream(4); @@ -158,15 +218,18 @@ TEST_F(QuicSessionTest, IsClosedStreamLocallyCreated) { } TEST_F(QuicSessionTest, IsClosedStreamPeerCreated) { - session_.GetIncomingReliableStream(3); - session_.GetIncomingReliableStream(5); + ReliableQuicStream* stream3 = session_.GetIncomingReliableStream(3); + ReliableQuicStreamPeer::SetHeadersDecompressed(stream3, true); + ReliableQuicStream* stream5 = session_.GetIncomingReliableStream(5); + ReliableQuicStreamPeer::SetHeadersDecompressed(stream5, true); CheckClosedStreams(); CloseStream(3); CheckClosedStreams(); CloseStream(5); // Create stream id 9, and implicitly 7 - session_.GetIncomingReliableStream(9); + ReliableQuicStream* stream9 = session_.GetIncomingReliableStream(9); + ReliableQuicStreamPeer::SetHeadersDecompressed(stream9, true); CheckClosedStreams(); // Close 9, but make sure 7 is still not closed CloseStream(9); @@ -194,29 +257,77 @@ TEST_F(QuicSessionTest, OnCanWrite) { TestStream* stream4 = session_.CreateOutgoingReliableStream(); TestStream* stream6 = session_.CreateOutgoingReliableStream(); - session_.MarkWriteBlocked(2); - session_.MarkWriteBlocked(6); - session_.MarkWriteBlocked(4); + session_.MarkWriteBlocked(stream2->id(), kSomeMiddlePriority); + session_.MarkWriteBlocked(stream6->id(), kSomeMiddlePriority); + session_.MarkWriteBlocked(stream4->id(), kSomeMiddlePriority); InSequence s; + StreamBlocker stream2_blocker(&session_, stream2->id()); EXPECT_CALL(*stream2, OnCanWrite()).WillOnce( // Reregister, to test the loop limit. - testing::InvokeWithoutArgs(&session_, &TestSession::MarkTwoWriteBlocked)); + InvokeWithoutArgs(&stream2_blocker, &StreamBlocker::MarkWriteBlocked)); EXPECT_CALL(*stream6, OnCanWrite()); EXPECT_CALL(*stream4, OnCanWrite()); EXPECT_FALSE(session_.OnCanWrite()); } +TEST_F(QuicSessionTest, BufferedHandshake) { + EXPECT_FALSE(session_.HasPendingHandshake()); // Default value. + + // Test that blocking other streams does not change our status. + TestStream* stream2 = session_.CreateOutgoingReliableStream(); + StreamBlocker stream2_blocker(&session_, stream2->id()); + stream2_blocker.MarkWriteBlocked(); + EXPECT_FALSE(session_.HasPendingHandshake()); + + TestStream* stream3 = session_.CreateOutgoingReliableStream(); + StreamBlocker stream3_blocker(&session_, stream3->id()); + stream3_blocker.MarkWriteBlocked(); + EXPECT_FALSE(session_.HasPendingHandshake()); + + // Blocking (due to buffering of) the Crypto stream is detected. + session_.MarkWriteBlocked(kCryptoStreamId, kSomeMiddlePriority); + EXPECT_TRUE(session_.HasPendingHandshake()); + + TestStream* stream4 = session_.CreateOutgoingReliableStream(); + StreamBlocker stream4_blocker(&session_, stream4->id()); + stream4_blocker.MarkWriteBlocked(); + EXPECT_TRUE(session_.HasPendingHandshake()); + + InSequence s; + // Force most streams to re-register, which is common scenario when we block + // the Crypto stream, and only the crypto stream can "really" write. + + // Due to prioritization, we *should* be asked to write the crypto stream + // first. + // Don't re-register the crypto stream (which signals complete writing). + TestCryptoStream* crypto_stream = session_.GetCryptoStream(); + EXPECT_CALL(*crypto_stream, OnCanWrite()); + + // Re-register all other streams, to show they weren't able to proceed. + EXPECT_CALL(*stream2, OnCanWrite()).WillOnce( + InvokeWithoutArgs(&stream2_blocker, &StreamBlocker::MarkWriteBlocked)); + + EXPECT_CALL(*stream3, OnCanWrite()).WillOnce( + InvokeWithoutArgs(&stream3_blocker, &StreamBlocker::MarkWriteBlocked)); + + EXPECT_CALL(*stream4, OnCanWrite()).WillOnce( + InvokeWithoutArgs(&stream4_blocker, &StreamBlocker::MarkWriteBlocked)); + + EXPECT_FALSE(session_.OnCanWrite()); + EXPECT_FALSE(session_.HasPendingHandshake()); // Crypto stream wrote. +} + TEST_F(QuicSessionTest, OnCanWriteWithClosedStream) { TestStream* stream2 = session_.CreateOutgoingReliableStream(); TestStream* stream4 = session_.CreateOutgoingReliableStream(); - session_.CreateOutgoingReliableStream(); // stream 6 + TestStream* stream6 = session_.CreateOutgoingReliableStream(); - session_.MarkWriteBlocked(2); - session_.MarkWriteBlocked(6); - session_.MarkWriteBlocked(4); - CloseStream(6); + session_.MarkWriteBlocked(stream2->id(), kSomeMiddlePriority); + session_.MarkWriteBlocked(stream6->id(), kSomeMiddlePriority); + session_.MarkWriteBlocked(stream4->id(), kSomeMiddlePriority); + CloseStream(stream6->id()); InSequence s; EXPECT_CALL(*stream2, OnCanWrite()); @@ -227,10 +338,6 @@ TEST_F(QuicSessionTest, OnCanWriteWithClosedStream) { // Regression test for http://crbug.com/248737 TEST_F(QuicSessionTest, OutOfOrderHeaders) { QuicSpdyCompressor compressor; - SpdyHeaderBlock headers; - headers[":host"] = "www.google.com"; - headers[":path"] = "/index.hml"; - headers[":scheme"] = "http"; vector<QuicStreamFrame> frames; QuicPacketHeader header; header.public_header.guid = session_.guid(); @@ -241,24 +348,24 @@ TEST_F(QuicSessionTest, OutOfOrderHeaders) { stream4->CloseWriteSide(); // Create frame with headers for stream2. - string compressed_headers1 = compressor.CompressHeaders(headers); + string compressed_headers1 = compressor.CompressHeaders(headers_); QuicStreamFrame frame1(stream2->id(), false, 0, compressed_headers1); // Create frame with headers for stream4. - string compressed_headers2 = compressor.CompressHeaders(headers); + string compressed_headers2 = compressor.CompressHeaders(headers_); QuicStreamFrame frame2(stream4->id(), true, 0, compressed_headers2); // Process the second frame first. This will cause the headers to // be queued up and processed after the first frame is processed. frames.push_back(frame2); - session_.OnPacket(IPEndPoint(), IPEndPoint(), header, frames); + session_.OnStreamFrames(frames); // Process the first frame, and un-cork the buffered headers. frames[0] = frame1; - session_.OnPacket(IPEndPoint(), IPEndPoint(), header, frames); + session_.OnStreamFrames(frames); // Ensure that the streams actually close and we don't DCHECK. - session_.ConnectionClose(QUIC_CONNECTION_TIMED_OUT, true); + connection_->CloseConnection(QUIC_CONNECTION_TIMED_OUT, true); } TEST_F(QuicSessionTest, SendGoAway) { @@ -282,6 +389,70 @@ TEST_F(QuicSessionTest, IncreasedTimeoutAfterCryptoHandshake) { QuicConnectionPeer::GetNetworkTimeout(connection_).ToSeconds()); } +TEST_F(QuicSessionTest, ZombieStream) { + StrictMock<MockConnection>* connection = + new StrictMock<MockConnection>(guid_, IPEndPoint(), false); + TestSession session(connection, /*is_server=*/ false); + + TestStream* stream3 = session.CreateOutgoingReliableStream(); + EXPECT_EQ(3u, stream3->id()); + TestStream* stream5 = session.CreateOutgoingReliableStream(); + EXPECT_EQ(5u, stream5->id()); + EXPECT_EQ(2u, session.GetNumOpenStreams()); + + // Reset the stream, but since the headers have not been decompressed + // it will become a zombie and will continue to process data + // until the headers are decompressed. + EXPECT_CALL(*connection, SendRstStream(3, QUIC_STREAM_CANCELLED)); + session.SendRstStream(3, QUIC_STREAM_CANCELLED); + + EXPECT_EQ(1u, session.GetNumOpenStreams()); + + vector<QuicStreamFrame> frames; + QuicPacketHeader header; + header.public_header.guid = session_.guid(); + + // Create frame with headers for stream2. + QuicSpdyCompressor compressor; + string compressed_headers1 = compressor.CompressHeaders(headers_); + QuicStreamFrame frame1(stream3->id(), false, 0, compressed_headers1); + + // Process the second frame first. This will cause the headers to + // be queued up and processed after the first frame is processed. + frames.push_back(frame1); + EXPECT_FALSE(stream3->headers_decompressed()); + + session.OnStreamFrames(frames); + EXPECT_EQ(1u, session.GetNumOpenStreams()); + + EXPECT_TRUE(connection->connected()); +} + +TEST_F(QuicSessionTest, ZombieStreamConnectionClose) { + StrictMock<MockConnection>* connection = + new StrictMock<MockConnection>(guid_, IPEndPoint(), false); + TestSession session(connection, /*is_server=*/ false); + + TestStream* stream3 = session.CreateOutgoingReliableStream(); + EXPECT_EQ(3u, stream3->id()); + TestStream* stream5 = session.CreateOutgoingReliableStream(); + EXPECT_EQ(5u, stream5->id()); + EXPECT_EQ(2u, session.GetNumOpenStreams()); + + stream3->CloseWriteSide(); + // Reset the stream, but since the headers have not been decompressed + // it will become a zombie and will continue to process data + // until the headers are decompressed. + EXPECT_CALL(*connection, SendRstStream(3, QUIC_STREAM_CANCELLED)); + session.SendRstStream(3, QUIC_STREAM_CANCELLED); + + EXPECT_EQ(1u, session.GetNumOpenStreams()); + + connection->CloseConnection(QUIC_CONNECTION_TIMED_OUT, false); + + EXPECT_EQ(0u, session.GetNumOpenStreams()); +} + } // namespace } // namespace test } // namespace net diff --git a/chromium/net/quic/quic_spdy_compressor.cc b/chromium/net/quic/quic_spdy_compressor.cc index 7efd45ce76f..6681493b66a 100644 --- a/chromium/net/quic/quic_spdy_compressor.cc +++ b/chromium/net/quic/quic_spdy_compressor.cc @@ -20,8 +20,21 @@ QuicSpdyCompressor::QuicSpdyCompressor() QuicSpdyCompressor::~QuicSpdyCompressor() { } +string QuicSpdyCompressor::CompressHeadersWithPriority( + QuicPriority priority, + const SpdyHeaderBlock& headers) { + return CompressHeadersInternal(priority, headers, true); +} + string QuicSpdyCompressor::CompressHeaders( const SpdyHeaderBlock& headers) { + return CompressHeadersInternal(0, headers, false); +} + +string QuicSpdyCompressor::CompressHeadersInternal( + QuicPriority priority, + const SpdyHeaderBlock& headers, + bool write_priority) { // TODO(rch): Modify the SpdyFramer to expose a // CreateCompressedHeaderBlock method, or some such. SpdyStreamId stream_id = 3; // unused. @@ -34,12 +47,19 @@ string QuicSpdyCompressor::CompressHeaders( string serialized = string(frame->data() + header_frame_prefix_len, frame->size() - header_frame_prefix_len); uint32 serialized_len = serialized.length(); + char priority_str[sizeof(priority)]; + memcpy(&priority_str, &priority, sizeof(priority)); char id_str[sizeof(header_sequence_id_)]; memcpy(&id_str, &header_sequence_id_, sizeof(header_sequence_id_)); char len_str[sizeof(serialized_len)]; memcpy(&len_str, &serialized_len, sizeof(serialized_len)); string compressed; - compressed.reserve(arraysize(id_str) + arraysize(len_str) + serialized_len); + int priority_len = write_priority ? arraysize(priority_str) : 0; + compressed.reserve( + priority_len + arraysize(id_str) + arraysize(len_str) + serialized_len); + if (write_priority) { + compressed.append(priority_str, arraysize(priority_str)); + } compressed.append(id_str, arraysize(id_str)); compressed.append(len_str, arraysize(len_str)); compressed.append(serialized); diff --git a/chromium/net/quic/quic_spdy_compressor.h b/chromium/net/quic/quic_spdy_compressor.h index c88c47eae7e..53a70603fea 100644 --- a/chromium/net/quic/quic_spdy_compressor.h +++ b/chromium/net/quic/quic_spdy_compressor.h @@ -24,9 +24,19 @@ class NET_EXPORT_PRIVATE QuicSpdyCompressor { QuicSpdyCompressor(); ~QuicSpdyCompressor(); + // Returns a string comprised of [header_sequence_id, compressed_headers]. std::string CompressHeaders(const SpdyHeaderBlock& headers); + // Returns a string comprised of + // [priority, header_sequence_id, compressed_headers] + std::string CompressHeadersWithPriority(QuicPriority priority, + const SpdyHeaderBlock& headers); + private: + std::string CompressHeadersInternal(QuicPriority priority, + const SpdyHeaderBlock& headers, + bool write_priority); + SpdyFramer spdy_framer_; QuicHeaderId header_sequence_id_; diff --git a/chromium/net/quic/quic_stream_factory.cc b/chromium/net/quic/quic_stream_factory.cc index fba7f0b7dec..4e1a8eddc20 100644 --- a/chromium/net/quic/quic_stream_factory.cc +++ b/chromium/net/quic/quic_stream_factory.cc @@ -16,6 +16,7 @@ #include "net/cert/cert_verifier.h" #include "net/dns/host_resolver.h" #include "net/dns/single_request_host_resolver.h" +#include "net/http/http_server_properties.h" #include "net/quic/crypto/proof_verifier_chromium.h" #include "net/quic/crypto/quic_random.h" #include "net/quic/quic_client_session.h" @@ -147,9 +148,10 @@ void QuicStreamFactory::Job::OnIOComplete(int rv) { int QuicStreamFactory::Job::DoResolveHost() { io_state_ = STATE_RESOLVE_HOST_COMPLETE; return host_resolver_.Resolve( - HostResolver::RequestInfo(host_port_proxy_pair_.first), &address_list_, - base::Bind(&QuicStreamFactory::Job::OnIOComplete, - base::Unretained(this)), + HostResolver::RequestInfo(host_port_proxy_pair_.first), + DEFAULT_PRIORITY, + &address_list_, + base::Bind(&QuicStreamFactory::Job::OnIOComplete, base::Unretained(this)), net_log_); } @@ -157,16 +159,6 @@ int QuicStreamFactory::Job::DoResolveHostComplete(int rv) { if (rv != OK) return rv; - // TODO(rch): remove this code! - AddressList::iterator it = address_list_.begin(); - while (it != address_list_.end()) { - if (it->GetFamily() == ADDRESS_FAMILY_IPV6) { - it = address_list_.erase(it); - } else { - it++; - } - } - DCHECK(!factory_->HasActiveSession(host_port_proxy_pair_)); io_state_ = STATE_CONNECT; return OK; @@ -226,6 +218,7 @@ int QuicStreamFactory::Job::DoConnect() { cert_verifier_, address_list_, net_log_); session_->StartReading(); int rv = session_->CryptoConnect( + factory_->require_confirmation() || is_https_, base::Bind(&QuicStreamFactory::Job::OnIOComplete, base::Unretained(this))); return rv; @@ -244,11 +237,14 @@ int QuicStreamFactory::Job::DoConnectComplete(int rv) { QuicStreamFactory::QuicStreamFactory( HostResolver* host_resolver, ClientSocketFactory* client_socket_factory, + base::WeakPtr<HttpServerProperties> http_server_properties, QuicCryptoClientStreamFactory* quic_crypto_client_stream_factory, QuicRandom* random_generator, QuicClock* clock) - : host_resolver_(host_resolver), + : require_confirmation_(true), + host_resolver_(host_resolver), client_socket_factory_(client_socket_factory), + http_server_properties_(http_server_properties), quic_crypto_client_stream_factory_(quic_crypto_client_stream_factory), random_generator_(random_generator), clock_(clock), @@ -301,6 +297,8 @@ int QuicStreamFactory::Create(const HostPortProxyPair& host_port_proxy_pair, void QuicStreamFactory::OnJobComplete(Job* job, int rv) { if (rv == OK) { + require_confirmation_ = false; + // Create all the streams, but do not notify them yet. for (RequestSet::iterator it = job_requests_map_[job].begin(); it != job_requests_map_[job].end() ; ++it) { @@ -351,6 +349,13 @@ void QuicStreamFactory::OnSessionClose(QuicClientSession* session) { DCHECK(active_sessions_.count(*it)); DCHECK_EQ(session, active_sessions_[*it]); active_sessions_.erase(*it); + if (!session->IsCryptoHandshakeConfirmed() && http_server_properties_) { + // TODO(rch): In the special case where the session has received no + // packets from the peer, we should consider blacklisting this + // differently so that we still race TCP but we don't consider the + // session connected until the handshake has been confirmed. + http_server_properties_->SetBrokenAlternateProtocol(it->first); + } } all_sessions_.erase(session); session_aliases_.erase(session); @@ -393,6 +398,7 @@ base::Value* QuicStreamFactory::QuicStreamFactoryInfoToValue() const { void QuicStreamFactory::OnIPAddressChanged() { CloseAllSessions(ERR_NETWORK_CHANGED); + require_confirmation_ = true; } bool QuicStreamFactory::HasActiveSession( @@ -408,10 +414,10 @@ QuicClientSession* QuicStreamFactory::CreateSession( const BoundNetLog& net_log) { QuicGuid guid = random_generator_->RandUint64(); IPEndPoint addr = *address_list.begin(); - DatagramClientSocket* socket = + scoped_ptr<DatagramClientSocket> socket( client_socket_factory_->CreateDatagramClientSocket( DatagramSocket::DEFAULT_BIND, base::Bind(&base::RandInt), - net_log.net_log(), net_log.source()); + net_log.net_log(), net_log.source())); socket->Connect(addr); // We should adaptively set this buffer size, but for now, we'll use a size @@ -420,24 +426,16 @@ QuicClientSession* QuicStreamFactory::CreateSession( // revisit this setting and test for its impact. const int32 kSocketBufferSize(kMaxPacketSize * 100); // Support 100 packets. socket->SetReceiveBufferSize(kSocketBufferSize); - // TODO(jar): What should the UDP send buffer be set to? If the send buffer - // is too large, then we might(?) wastefully queue packets in the OS, when - // we'd rather construct packets just in time. We do however expect that the - // calculated send rate (paced, or ack clocked), will be well below the egress - // rate of the local machine, so that *shouldn't* be a problem. - // If the buffer setting is too small, then we will starve our outgoing link - // on a fast connection, because we won't respond fast enough to the many - // async callbacks to get data from us. On the other hand, until we have real - // pacing support (beyond ack-clocked pacing), we get a bit of adhoc-pacing by - // requiring the application to refill this OS buffer (ensuring that we don't - // blast a pile of packets at the kernel's max egress rate). - // socket->SetSendBufferSize(????); + // Set a buffer large enough to contain the initial CWND's worth of packet + // to work around the problem with CHLO packets being sent out with the + // wrong encryption level, when the send buffer is full. + socket->SetSendBufferSize(kMaxPacketSize * 20); // Support 20 packets. QuicConnectionHelper* helper = new QuicConnectionHelper( base::MessageLoop::current()->message_loop_proxy().get(), clock_.get(), random_generator_, - socket); + socket.get()); QuicConnection* connection = new QuicConnection(guid, addr, helper, false, QuicVersionMax()); @@ -447,7 +445,7 @@ QuicClientSession* QuicStreamFactory::CreateSession( DCHECK(crypto_config); QuicClientSession* session = - new QuicClientSession(connection, socket, this, + new QuicClientSession(connection, socket.Pass(), this, quic_crypto_client_stream_factory_, host_port_proxy_pair.first.host(), config_, crypto_config, net_log.net_log()); diff --git a/chromium/net/quic/quic_stream_factory.h b/chromium/net/quic/quic_stream_factory.h index 963a60369ac..2319e728ed8 100644 --- a/chromium/net/quic/quic_stream_factory.h +++ b/chromium/net/quic/quic_stream_factory.h @@ -25,6 +25,7 @@ namespace net { class CertVerifier; class ClientSocketFactory; class HostResolver; +class HttpServerProperties; class QuicClock; class QuicClientSession; class QuicCryptoClientStreamFactory; @@ -76,6 +77,7 @@ class NET_EXPORT_PRIVATE QuicStreamFactory QuicStreamFactory( HostResolver* host_resolver, ClientSocketFactory* client_socket_factory, + base::WeakPtr<HttpServerProperties> http_server_properties, QuicCryptoClientStreamFactory* quic_crypto_client_stream_factory, QuicRandom* random_generator, QuicClock* clock); @@ -120,6 +122,12 @@ class NET_EXPORT_PRIVATE QuicStreamFactory // IP address changes. virtual void OnIPAddressChanged() OVERRIDE; + bool require_confirmation() const { return require_confirmation_; } + + void set_require_confirmation(bool require_confirmation) { + require_confirmation_ = require_confirmation; + } + private: class Job; @@ -148,8 +156,10 @@ class NET_EXPORT_PRIVATE QuicStreamFactory QuicCryptoClientConfig* GetOrCreateCryptoConfig( const HostPortProxyPair& host_port_proxy_pair); + bool require_confirmation_; HostResolver* host_resolver_; ClientSocketFactory* client_socket_factory_; + base::WeakPtr<HttpServerProperties> http_server_properties_; QuicCryptoClientStreamFactory* quic_crypto_client_stream_factory_; QuicRandom* random_generator_; scoped_ptr<QuicClock> clock_; diff --git a/chromium/net/quic/quic_stream_factory_test.cc b/chromium/net/quic/quic_stream_factory_test.cc index 2f46772d0fc..a546d874ea1 100644 --- a/chromium/net/quic/quic_stream_factory_test.cc +++ b/chromium/net/quic/quic_stream_factory_test.cc @@ -29,12 +29,14 @@ class QuicStreamFactoryTest : public ::testing::Test { QuicStreamFactoryTest() : clock_(new MockClock()), factory_(&host_resolver_, &socket_factory_, + base::WeakPtr<HttpServerProperties>(), &crypto_client_stream_factory_, &random_generator_, clock_), host_port_proxy_pair_(HostPortPair("www.google.com", 443), ProxyServer::Direct()), is_https_(false), cert_verifier_(CertVerifier::CreateDefault()) { + factory_.set_require_confirmation(false); } scoped_ptr<QuicEncryptedPacket> ConstructRstPacket( @@ -45,11 +47,12 @@ class QuicStreamFactoryTest : public ::testing::Test { header.public_header.reset_flag = false; header.public_header.version_flag = true; header.packet_sequence_number = num; + header.public_header.sequence_number_length = PACKET_1BYTE_SEQUENCE_NUMBER; header.entropy_flag = false; header.fec_flag = false; header.fec_group = 0; - QuicRstStreamFrame rst(stream_id, QUIC_STREAM_NO_ERROR); + QuicRstStreamFrame rst(stream_id, QUIC_ERROR_PROCESSING_STREAM); return scoped_ptr<QuicEncryptedPacket>( ConstructPacket(header, QuicFrame(&rst))); } @@ -171,7 +174,12 @@ TEST_F(QuicStreamFactoryTest, MaxOpenStream) { MockRead reads[] = { MockRead(ASYNC, OK, 0) // EOF }; - DeterministicSocketData socket_data(reads, arraysize(reads), NULL, 0); + scoped_ptr<QuicEncryptedPacket> rst(ConstructRstPacket(1, 3)); + MockWrite writes[] = { + MockWrite(ASYNC, rst->data(), rst->length(), 1), + }; + DeterministicSocketData socket_data(reads, arraysize(reads), + writes, arraysize(writes)); socket_factory_.AddSocketDataProvider(&socket_data); socket_data.StopAfter(1); @@ -342,6 +350,7 @@ TEST_F(QuicStreamFactoryTest, OnIPAddressChanged) { factory_.OnIPAddressChanged(); EXPECT_EQ(ERR_NETWORK_CHANGED, stream->ReadResponseHeaders(callback_.callback())); + EXPECT_TRUE(factory_.require_confirmation()); // Now attempting to request a stream to the same origin should create // a new session. diff --git a/chromium/net/quic/quic_stream_sequencer.cc b/chromium/net/quic/quic_stream_sequencer.cc index 7cf67d351e3..a57c05f83c4 100644 --- a/chromium/net/quic/quic_stream_sequencer.cc +++ b/chromium/net/quic/quic_stream_sequencer.cc @@ -74,17 +74,21 @@ bool QuicStreamSequencer::OnStreamFrame(const QuicStreamFrame& frame) { return true; } - if (frame.fin) { - CloseStreamAtOffset(frame.offset + frame.data.size()); - } - QuicStreamOffset byte_offset = frame.offset; const char* data = frame.data.data(); size_t data_len = frame.data.size(); - if (data_len == 0) { - // TODO(rch): Close the stream if there was no data and no fin. - return true; + if (data_len == 0 && !frame.fin) { + // Stream frames must have data or a fin flag. + stream_->ConnectionClose(QUIC_INVALID_STREAM_FRAME, false); + return false; + } + + if (frame.fin) { + CloseStreamAtOffset(frame.offset + frame.data.size()); + if (data_len == 0) { + return true; + } } if (byte_offset == num_bytes_consumed_) { @@ -96,7 +100,7 @@ bool QuicStreamSequencer::OnStreamFrame(const QuicStreamFrame& frame) { return true; } if (bytes_consumed > data_len) { - stream_->Close(QUIC_SERVER_ERROR_PROCESSING_STREAM); + stream_->Close(QUIC_ERROR_PROCESSING_STREAM); return false; } else if (bytes_consumed == data_len) { FlushBufferedFrames(); @@ -211,7 +215,7 @@ void QuicStreamSequencer::MarkConsumed(size_t num_bytes_consumed) { << " end_offset: " << end_offset << " offset: " << it->first << " length: " << it->second.length(); - stream_->Close(QUIC_SERVER_ERROR_PROCESSING_STREAM); + stream_->Close(QUIC_ERROR_PROCESSING_STREAM); return; } @@ -262,7 +266,7 @@ void QuicStreamSequencer::FlushBufferedFrames() { return; } if (bytes_consumed > data->size()) { - stream_->Close(QUIC_SERVER_ERROR_PROCESSING_STREAM); // Programming error + stream_->Close(QUIC_ERROR_PROCESSING_STREAM); // Programming error return; } else if (bytes_consumed == data->size()) { frames_.erase(it); diff --git a/chromium/net/quic/quic_stream_sequencer.h b/chromium/net/quic/quic_stream_sequencer.h index fe9fba56830..a450bef988d 100644 --- a/chromium/net/quic/quic_stream_sequencer.h +++ b/chromium/net/quic/quic_stream_sequencer.h @@ -29,8 +29,6 @@ class ReliableQuicStream; // TOOD(alyssar) add some checks for overflow attempts [1, 256,] [2, 256] class NET_EXPORT_PRIVATE QuicStreamSequencer { public: - static size_t kMaxUdpPacketSize; - explicit QuicStreamSequencer(ReliableQuicStream* quic_stream); QuicStreamSequencer(size_t max_frame_memory, ReliableQuicStream* quic_stream); diff --git a/chromium/net/quic/quic_stream_sequencer_test.cc b/chromium/net/quic/quic_stream_sequencer_test.cc index 0d40db92636..21568d60653 100644 --- a/chromium/net/quic/quic_stream_sequencer_test.cc +++ b/chromium/net/quic/quic_stream_sequencer_test.cc @@ -58,11 +58,8 @@ class QuicStreamSequencerPeer : public QuicStreamSequencer { void SetMemoryLimit(size_t limit) { max_frame_memory_ = limit; } - - const ReliableQuicStream* stream() const { return stream_; } uint64 num_bytes_consumed() const { return num_bytes_consumed_; } const FrameMap* frames() const { return &frames_; } - int32 max_frame_memory() const { return max_frame_memory_; } QuicStreamOffset close_offset() const { return close_offset_; } }; @@ -74,6 +71,7 @@ class MockStream : public ReliableQuicStream { MOCK_METHOD1(TerminateFromPeer, void(bool half_close)); MOCK_METHOD2(ProcessData, uint32(const char* data, uint32 data_len)); + MOCK_METHOD2(ConnectionClose, void(QuicErrorCode error, bool from_peer)); MOCK_METHOD1(Close, void(QuicRstStreamErrorCode error)); MOCK_METHOD0(OnCanWrite, void()); }; @@ -185,7 +183,8 @@ TEST_F(QuicStreamSequencerTest, FullFrameConsumed) { } TEST_F(QuicStreamSequencerTest, EmptyFrame) { - EXPECT_TRUE(sequencer_->OnFrame(0, "")); + EXPECT_CALL(stream_, ConnectionClose(QUIC_INVALID_STREAM_FRAME, false)); + EXPECT_FALSE(sequencer_->OnFrame(0, "")); EXPECT_EQ(0u, sequencer_->frames()->size()); EXPECT_EQ(0u, sequencer_->num_bytes_consumed()); } @@ -402,7 +401,7 @@ TEST_F(QuicStreamSequencerTest, MarkConsumedError) { // Now, attempt to mark consumed more data than was readable // and expect the stream to be closed. - EXPECT_CALL(stream_, Close(QUIC_SERVER_ERROR_PROCESSING_STREAM)); + EXPECT_CALL(stream_, Close(QUIC_ERROR_PROCESSING_STREAM)); EXPECT_DFATAL(sequencer_->MarkConsumed(4), "Invalid argument to MarkConsumed. num_bytes_consumed_: 3 " "end_offset: 4 offset: 9 length: 17"); diff --git a/chromium/net/quic/quic_time_test.cc b/chromium/net/quic/quic_time_test.cc index c4ea0e20047..18b1de4f3d3 100644 --- a/chromium/net/quic/quic_time_test.cc +++ b/chromium/net/quic/quic_time_test.cc @@ -9,23 +9,19 @@ namespace net { namespace test { -class QuicTimeDeltaTest : public ::testing::Test { - protected: -}; - -TEST_F(QuicTimeDeltaTest, Zero) { +TEST(QuicTimeDeltaTest, Zero) { EXPECT_TRUE(QuicTime::Delta::Zero().IsZero()); EXPECT_FALSE(QuicTime::Delta::Zero().IsInfinite()); EXPECT_FALSE(QuicTime::Delta::FromMilliseconds(1).IsZero()); } -TEST_F(QuicTimeDeltaTest, Infinite) { +TEST(QuicTimeDeltaTest, Infinite) { EXPECT_TRUE(QuicTime::Delta::Infinite().IsInfinite()); EXPECT_FALSE(QuicTime::Delta::Zero().IsInfinite()); EXPECT_FALSE(QuicTime::Delta::FromMilliseconds(1).IsInfinite()); } -TEST_F(QuicTimeDeltaTest, FromTo) { +TEST(QuicTimeDeltaTest, FromTo) { EXPECT_EQ(QuicTime::Delta::FromMilliseconds(1), QuicTime::Delta::FromMicroseconds(1000)); EXPECT_EQ(QuicTime::Delta::FromSeconds(1), @@ -41,17 +37,24 @@ TEST_F(QuicTimeDeltaTest, FromTo) { QuicTime::Delta::FromSeconds(2).ToMicroseconds()); } -TEST_F(QuicTimeDeltaTest, Add) { +TEST(QuicTimeDeltaTest, Add) { EXPECT_EQ(QuicTime::Delta::FromMicroseconds(2000), QuicTime::Delta::Zero().Add(QuicTime::Delta::FromMilliseconds(2))); } -TEST_F(QuicTimeDeltaTest, Subtract) { +TEST(QuicTimeDeltaTest, Subtract) { EXPECT_EQ(QuicTime::Delta::FromMicroseconds(1000), QuicTime::Delta::FromMilliseconds(2).Subtract( QuicTime::Delta::FromMilliseconds(1))); } +TEST(QuicTimeDeltaTest, NotEqual) { + EXPECT_TRUE(QuicTime::Delta::FromSeconds(0) != + QuicTime::Delta::FromSeconds(1)); + EXPECT_FALSE(QuicTime::Delta::FromSeconds(0) != + QuicTime::Delta::FromSeconds(0)); +} + class QuicTimeTest : public ::testing::Test { protected: MockClock clock_; @@ -109,5 +112,14 @@ TEST_F(QuicTimeTest, MockClock) { EXPECT_EQ(now, time); } +TEST_F(QuicTimeTest, LE) { + const QuicTime zero = QuicTime::Zero(); + const QuicTime one = zero.Add(QuicTime::Delta::FromSeconds(1)); + EXPECT_TRUE(zero <= zero); + EXPECT_TRUE(zero <= one); + EXPECT_TRUE(one <= one); + EXPECT_FALSE(one <= zero); +} + } // namespace test } // namespace net diff --git a/chromium/net/quic/quic_utils.cc b/chromium/net/quic/quic_utils.cc index 17957440a3d..cfb3cb7db5c 100644 --- a/chromium/net/quic/quic_utils.cc +++ b/chromium/net/quic/quic_utils.cc @@ -111,14 +111,6 @@ void QuicUtils::SerializeUint128(uint128 v, uint8* out) { memcpy(out + sizeof(lo), &hi, sizeof(hi)); } -// static -uint128 QuicUtils::ParseUint128(const uint8* in) { - uint64 lo, hi; - memcpy(&lo, in, sizeof(lo)); - memcpy(&hi, in + sizeof(lo), sizeof(hi)); - return uint128(hi, lo); -} - #define RETURN_STRING_LITERAL(x) \ case x: \ return #x; @@ -128,10 +120,11 @@ const char* QuicUtils::StreamErrorToString(QuicRstStreamErrorCode error) { switch (error) { RETURN_STRING_LITERAL(QUIC_STREAM_NO_ERROR); RETURN_STRING_LITERAL(QUIC_STREAM_CONNECTION_ERROR); - RETURN_STRING_LITERAL(QUIC_SERVER_ERROR_PROCESSING_STREAM); + RETURN_STRING_LITERAL(QUIC_ERROR_PROCESSING_STREAM); RETURN_STRING_LITERAL(QUIC_MULTIPLE_TERMINATION_OFFSETS); RETURN_STRING_LITERAL(QUIC_BAD_APPLICATION_PAYLOAD); RETURN_STRING_LITERAL(QUIC_STREAM_PEER_GOING_AWAY); + RETURN_STRING_LITERAL(QUIC_STREAM_CANCELLED); RETURN_STRING_LITERAL(QUIC_STREAM_LAST_ERROR); } // Return a default value so that we return this when |error| doesn't match @@ -148,11 +141,14 @@ const char* QuicUtils::ErrorToString(QuicErrorCode error) { RETURN_STRING_LITERAL(QUIC_STREAM_DATA_AFTER_TERMINATION); RETURN_STRING_LITERAL(QUIC_INVALID_PACKET_HEADER); RETURN_STRING_LITERAL(QUIC_INVALID_FRAME_DATA); + RETURN_STRING_LITERAL(QUIC_MISSING_PAYLOAD); RETURN_STRING_LITERAL(QUIC_INVALID_FEC_DATA); + RETURN_STRING_LITERAL(QUIC_INVALID_STREAM_DATA); RETURN_STRING_LITERAL(QUIC_INVALID_RST_STREAM_DATA); RETURN_STRING_LITERAL(QUIC_INVALID_CONNECTION_CLOSE_DATA); RETURN_STRING_LITERAL(QUIC_INVALID_GOAWAY_DATA); RETURN_STRING_LITERAL(QUIC_INVALID_ACK_DATA); + RETURN_STRING_LITERAL(QUIC_INVALID_CONGESTION_FEEDBACK_DATA); RETURN_STRING_LITERAL(QUIC_INVALID_VERSION_NEGOTIATION_PACKET); RETURN_STRING_LITERAL(QUIC_INVALID_PUBLIC_RST_PACKET); RETURN_STRING_LITERAL(QUIC_DECRYPTION_FAILURE); @@ -175,6 +171,7 @@ const char* QuicUtils::ErrorToString(QuicErrorCode error) { RETURN_STRING_LITERAL(QUIC_CRYPTO_MESSAGE_PARAMETER_NO_OVERLAP); RETURN_STRING_LITERAL(QUIC_CRYPTO_MESSAGE_INDEX_NOT_FOUND); RETURN_STRING_LITERAL(QUIC_INVALID_STREAM_ID); + RETURN_STRING_LITERAL(QUIC_INVALID_PRIORITY); RETURN_STRING_LITERAL(QUIC_TOO_MANY_OPEN_STREAMS); RETURN_STRING_LITERAL(QUIC_PUBLIC_RESET); RETURN_STRING_LITERAL(QUIC_INVALID_VERSION); @@ -185,6 +182,8 @@ const char* QuicUtils::ErrorToString(QuicErrorCode error) { RETURN_STRING_LITERAL(QUIC_CONNECTION_TIMED_OUT); RETURN_STRING_LITERAL(QUIC_ERROR_MIGRATING_ADDRESS); RETURN_STRING_LITERAL(QUIC_PACKET_WRITE_ERROR); + RETURN_STRING_LITERAL(QUIC_PACKET_READ_ERROR); + RETURN_STRING_LITERAL(QUIC_INVALID_STREAM_FRAME); RETURN_STRING_LITERAL(QUIC_PROOF_INVALID); RETURN_STRING_LITERAL(QUIC_CRYPTO_DUPLICATE_TAG); RETURN_STRING_LITERAL(QUIC_CRYPTO_ENCRYPTION_LEVEL_INCORRECT); diff --git a/chromium/net/quic/quic_utils.h b/chromium/net/quic/quic_utils.h index 6650dbf453b..37eb9d208ac 100644 --- a/chromium/net/quic/quic_utils.h +++ b/chromium/net/quic/quic_utils.h @@ -46,9 +46,6 @@ class NET_EXPORT_PRIVATE QuicUtils { // SerializeUint128 writes |v| in little-endian form to |out|. static void SerializeUint128(uint128 v, uint8* out); - // ParseUint128 parses a little-endian uint128 from |in| and returns it. - static uint128 ParseUint128(const uint8* in); - // Returns the name of the QuicRstStreamErrorCode as a char* static const char* StreamErrorToString(QuicRstStreamErrorCode error); diff --git a/chromium/net/quic/reliable_quic_stream.cc b/chromium/net/quic/reliable_quic_stream.cc index c2cb3f1ff52..41cb4576b58 100644 --- a/chromium/net/quic/reliable_quic_stream.cc +++ b/chromium/net/quic/reliable_quic_stream.cc @@ -6,12 +6,43 @@ #include "net/quic/quic_session.h" #include "net/quic/quic_spdy_decompressor.h" +#include "net/spdy/write_blocked_list.h" using base::StringPiece; using std::min; namespace net { +namespace { + +// This is somewhat arbitrary. It's possible, but unlikely, we will either fail +// to set a priority client-side, or cancel a stream before stripping the +// priority from the wire server-side. In either case, start out with a +// priority in the middle. +QuicPriority kDefaultPriority = 3; + +// Appends bytes from data into partial_data_buffer. Once partial_data_buffer +// reaches 4 bytes, copies the data into 'result' and clears +// partial_data_buffer. +// Returns the number of bytes consumed. +uint32 StripUint32(const char* data, uint32 data_len, + string* partial_data_buffer, + uint32* result) { + DCHECK_GT(4u, partial_data_buffer->length()); + size_t missing_size = 4 - partial_data_buffer->length(); + if (data_len < missing_size) { + StringPiece(data, data_len).AppendToString(partial_data_buffer); + return data_len; + } + StringPiece(data, missing_size).AppendToString(partial_data_buffer); + DCHECK_EQ(4u, partial_data_buffer->length()); + memcpy(result, partial_data_buffer->data(), 4); + partial_data_buffer->clear(); + return missing_size; +} + +} // namespace + ReliableQuicStream::ReliableQuicStream(QuicStreamId id, QuicSession* session) : sequencer_(this), @@ -21,12 +52,14 @@ ReliableQuicStream::ReliableQuicStream(QuicStreamId id, stream_bytes_read_(0), stream_bytes_written_(0), headers_decompressed_(false), + priority_(kDefaultPriority), headers_id_(0), decompression_failed_(false), stream_error_(QUIC_STREAM_NO_ERROR), connection_error_(QUIC_NO_ERROR), read_side_closed_(false), write_side_closed_(false), + priority_parsed_(false), fin_buffered_(false), fin_sent_(false) { } @@ -161,6 +194,12 @@ QuicConsumedData ReliableQuicStream::WriteData(StringPiece data, bool fin) { return WriteOrBuffer(data, fin); } + +void ReliableQuicStream::set_priority(QuicPriority priority) { + DCHECK_EQ(0u, stream_bytes_written_); + priority_ = priority; +} + QuicConsumedData ReliableQuicStream::WriteOrBuffer(StringPiece data, bool fin) { DCHECK(!fin_buffered_); @@ -203,27 +242,43 @@ void ReliableQuicStream::OnCanWrite() { QuicConsumedData ReliableQuicStream::WriteDataInternal( StringPiece data, bool fin) { + struct iovec iov = {const_cast<char*>(data.data()), + static_cast<size_t>(data.size())}; + return WritevDataInternal(&iov, 1, fin); +} + +QuicConsumedData ReliableQuicStream::WritevDataInternal(const struct iovec* iov, + int iov_count, + bool fin) { if (write_side_closed_) { DLOG(ERROR) << "Attempt to write when the write side is closed"; return QuicConsumedData(0, false); } + size_t write_length = 0u; + for (int i = 0; i < iov_count; ++i) { + write_length += iov[i].iov_len; + } QuicConsumedData consumed_data = - session()->WriteData(id(), data, stream_bytes_written_, fin); + session()->WritevData(id(), iov, iov_count, stream_bytes_written_, fin); stream_bytes_written_ += consumed_data.bytes_consumed; - if (consumed_data.bytes_consumed == data.length()) { + if (consumed_data.bytes_consumed == write_length) { if (fin && consumed_data.fin_consumed) { fin_sent_ = true; CloseWriteSide(); } else if (fin && !consumed_data.fin_consumed) { - session_->MarkWriteBlocked(id()); + session_->MarkWriteBlocked(id(), EffectivePriority()); } } else { - session_->MarkWriteBlocked(id()); + session_->MarkWriteBlocked(id(), EffectivePriority()); } return consumed_data; } +QuicPriority ReliableQuicStream::EffectivePriority() const { + return priority(); +} + void ReliableQuicStream::CloseReadSide() { if (read_side_closed_) { return; @@ -238,35 +293,22 @@ void ReliableQuicStream::CloseReadSide() { } uint32 ReliableQuicStream::ProcessRawData(const char* data, uint32 data_len) { + DCHECK_NE(0u, data_len); if (id() == kCryptoStreamId) { - if (data_len == 0) { - return 0; - } // The crypto stream does not use compression. return ProcessData(data, data_len); } + uint32 total_bytes_consumed = 0; if (headers_id_ == 0u) { - // The headers ID has not yet been read. Strip it from the beginning of - // the data stream. - DCHECK_GT(4u, headers_id_buffer_.length()); - size_t missing_size = 4 - headers_id_buffer_.length(); - if (data_len < missing_size) { - StringPiece(data, data_len).AppendToString(&headers_id_buffer_); - return data_len; + total_bytes_consumed += StripPriorityAndHeaderId(data, data_len); + data += total_bytes_consumed; + data_len -= total_bytes_consumed; + if (data_len == 0 || !session_->connection()->connected()) { + return total_bytes_consumed; } - total_bytes_consumed += missing_size; - StringPiece(data, missing_size).AppendToString(&headers_id_buffer_); - DCHECK_EQ(4u, headers_id_buffer_.length()); - memcpy(&headers_id_, headers_id_buffer_.data(), 4); - headers_id_buffer_.clear(); - data += missing_size; - data_len -= missing_size; } DCHECK_NE(0u, headers_id_); - if (data_len == 0) { - return total_bytes_consumed; - } // Once the headers are finished, we simply pass the data through. if (headers_decompressed_) { @@ -420,6 +462,10 @@ void ReliableQuicStream::CloseWriteSide() { } } +bool ReliableQuicStream::HasBufferedData() { + return !queued_data_.empty(); +} + void ReliableQuicStream::OnClose() { CloseReadSide(); CloseWriteSide(); @@ -433,4 +479,36 @@ void ReliableQuicStream::OnClose() { } } +uint32 ReliableQuicStream::StripPriorityAndHeaderId( + const char* data, uint32 data_len) { + uint32 total_bytes_parsed = 0; + + if (!priority_parsed_ && + session_->connection()->version() >= QUIC_VERSION_9 && + session_->connection()->is_server()) { + QuicPriority temporary_priority = priority_; + total_bytes_parsed = StripUint32( + data, data_len, &headers_id_and_priority_buffer_, &temporary_priority); + if (total_bytes_parsed > 0 && headers_id_and_priority_buffer_.size() == 0) { + priority_parsed_ = true; + // Spdy priorities are inverted, so the highest numerical value is the + // lowest legal priority. + if (temporary_priority > static_cast<QuicPriority>(kLowestPriority)) { + session_->connection()->SendConnectionClose(QUIC_INVALID_PRIORITY); + return 0; + } + priority_ = temporary_priority; + } + data += total_bytes_parsed; + data_len -= total_bytes_parsed; + } + if (data_len > 0 && headers_id_ == 0u) { + // The headers ID has not yet been read. Strip it from the beginning of + // the data stream. + total_bytes_parsed += StripUint32( + data, data_len, &headers_id_and_priority_buffer_, &headers_id_); + } + return total_bytes_parsed; +} + } // namespace net diff --git a/chromium/net/quic/reliable_quic_stream.h b/chromium/net/quic/reliable_quic_stream.h index 352325de36c..807882ca2de 100644 --- a/chromium/net/quic/reliable_quic_stream.h +++ b/chromium/net/quic/reliable_quic_stream.h @@ -95,6 +95,12 @@ class NET_EXPORT_PRIVATE ReliableQuicStream : public // becomes unblocked. virtual void OnDecompressorAvailable(); + // By default, this is the same as priority(), however it allows streams + // to temporarily alter effective priority. For example if a SPDY stream has + // compressed but not written headers it can write the headers with a higher + // priority. + virtual QuicPriority EffectivePriority() const; + QuicStreamId id() const { return id_; } QuicRstStreamErrorCode stream_error() const { return stream_error_; } @@ -108,7 +114,6 @@ class NET_EXPORT_PRIVATE ReliableQuicStream : public const IPEndPoint& GetPeerAddress() const; - Visitor* visitor() { return visitor_; } void set_visitor(Visitor* visitor) { visitor_ = visitor; } QuicSpdyCompressor* compressor(); @@ -135,10 +140,19 @@ class NET_EXPORT_PRIVATE ReliableQuicStream : public // Close the write side of the socket. Further writes will fail. void CloseWriteSide(); + bool HasBufferedData(); + bool fin_buffered() { return fin_buffered_; } QuicSession* session() { return session_; } + // Sets priority_ to priority. This should only be called before bytes are + // written to the server. + void set_priority(QuicPriority priority); + // This is protected because external classes should use EffectivePriority + // instead. + QuicPriority priority() const { return priority_; } + // Sends as much of 'data' to the connection as the connection will consume, // and then buffers any remaining data in queued_data_. // Returns (data.size(), true) as it always consumed all data: it returns for @@ -149,10 +163,19 @@ class NET_EXPORT_PRIVATE ReliableQuicStream : public // Returns the number of bytes consumed by the connection. QuicConsumedData WriteDataInternal(base::StringPiece data, bool fin); + // Sends as many bytes in the first |count| buffers of |iov| to the connection + // as the connection will consume. + // Returns the number of bytes consumed by the connection. + QuicConsumedData WritevDataInternal(const struct iovec* iov, + int iov_count, + bool fin); + private: friend class test::ReliableQuicStreamPeer; friend class QuicStreamUtils; + uint32 StripPriorityAndHeaderId(const char* data, uint32 data_len); + std::list<string> queued_data_; QuicStreamSequencer sequencer_; @@ -166,11 +189,13 @@ class NET_EXPORT_PRIVATE ReliableQuicStream : public uint64 stream_bytes_written_; // True if the headers have been completely decompresssed. bool headers_decompressed_; + // The priority of the stream, once parsed. + QuicPriority priority_; // ID of the header block sent by the peer, once parsed. QuicHeaderId headers_id_; - // Buffer into which we write bytes from the headers_id_ - // until it is fully parsed. - string headers_id_buffer_; + // Buffer into which we write bytes from priority_ and headers_id_ + // until each is fully parsed. + string headers_id_and_priority_buffer_; // Contains a copy of the decompressed headers_ until they are consumed // via ProcessData or Readv. string decompressed_headers_; @@ -190,6 +215,8 @@ class NET_EXPORT_PRIVATE ReliableQuicStream : public // True if the write side is closed, and further writes should fail. bool write_side_closed_; + // True if the priority has been read, false otherwise. + bool priority_parsed_; bool fin_buffered_; bool fin_sent_; }; diff --git a/chromium/net/quic/reliable_quic_stream_test.cc b/chromium/net/quic/reliable_quic_stream_test.cc index 7167a222345..063554d82c9 100644 --- a/chromium/net/quic/reliable_quic_stream_test.cc +++ b/chromium/net/quic/reliable_quic_stream_test.cc @@ -116,7 +116,7 @@ class ReliableQuicStreamTest : public ::testing::TestWithParam<bool> { scoped_ptr<QuicSpdyCompressor> compressor_; scoped_ptr<QuicSpdyDecompressor> decompressor_; SpdyHeaderBlock headers_; - BlockedList<QuicStreamId>* write_blocked_list_; + WriteBlockedList<QuicStreamId>* write_blocked_list_; }; TEST_F(ReliableQuicStreamTest, WriteAllData) { @@ -126,12 +126,10 @@ TEST_F(ReliableQuicStreamTest, WriteAllData) { 1 + QuicPacketCreator::StreamFramePacketOverhead( connection_->version(), PACKET_8BYTE_GUID, !kIncludeVersion, PACKET_6BYTE_SEQUENCE_NUMBER, NOT_IN_FEC_GROUP); - // TODO(rch): figure out how to get StrEq working here. - //EXPECT_CALL(*session_, WriteData(kStreamId, StrEq(kData1), _, _)).WillOnce( - EXPECT_CALL(*session_, WriteData(kStreamId, _, _, _)).WillOnce( + EXPECT_CALL(*session_, WritevData(kStreamId, _, 1, _, _)).WillOnce( Return(QuicConsumedData(kDataLen, true))); EXPECT_EQ(kDataLen, stream_->WriteData(kData1, false).bytes_consumed); - EXPECT_TRUE(write_blocked_list_->IsEmpty()); + EXPECT_FALSE(write_blocked_list_->HasWriteBlockedStreams()); } // TODO(rtenneti): Death tests crash on OS_ANDROID. @@ -142,10 +140,10 @@ TEST_F(ReliableQuicStreamTest, NoBlockingIfNoDataOrFin) { // Write no data and no fin. If we consume nothing we should not be write // blocked. EXPECT_DEBUG_DEATH({ - EXPECT_CALL(*session_, WriteData(kStreamId, _, _, _)).WillOnce( + EXPECT_CALL(*session_, WritevData(kStreamId, _, 1, _, _)).WillOnce( Return(QuicConsumedData(0, false))); stream_->WriteData(StringPiece(), false); - EXPECT_TRUE(write_blocked_list_->IsEmpty()); + EXPECT_FALSE(write_blocked_list_->HasWriteBlockedStreams()); }, ""); } #endif // GTEST_HAS_DEATH_TEST && !defined(NDEBUG) && !defined(OS_ANDROID) @@ -155,10 +153,10 @@ TEST_F(ReliableQuicStreamTest, BlockIfOnlySomeDataConsumed) { // Write some data and no fin. If we consume some but not all of the data, // we should be write blocked a not all the data was consumed. - EXPECT_CALL(*session_, WriteData(kStreamId, _, _, _)).WillOnce( + EXPECT_CALL(*session_, WritevData(kStreamId, _, 1, _, _)).WillOnce( Return(QuicConsumedData(1, false))); stream_->WriteData(StringPiece(kData1, 2), false); - ASSERT_EQ(1, write_blocked_list_->NumObjects()); + ASSERT_EQ(1, write_blocked_list_->NumBlockedStreams()); } @@ -169,10 +167,10 @@ TEST_F(ReliableQuicStreamTest, BlockIfFinNotConsumedWithData) { // we should be write blocked because the fin was not consumed. // (This should never actually happen as the fin should be sent out with the // last data) - EXPECT_CALL(*session_, WriteData(kStreamId, _, _, _)).WillOnce( + EXPECT_CALL(*session_, WritevData(kStreamId, _, 1, _, _)).WillOnce( Return(QuicConsumedData(2, false))); stream_->WriteData(StringPiece(kData1, 2), true); - ASSERT_EQ(1, write_blocked_list_->NumObjects()); + ASSERT_EQ(1, write_blocked_list_->NumBlockedStreams()); } TEST_F(ReliableQuicStreamTest, BlockIfSoloFinNotConsumed) { @@ -180,44 +178,39 @@ TEST_F(ReliableQuicStreamTest, BlockIfSoloFinNotConsumed) { // Write no data and a fin. If we consume nothing we should be write blocked, // as the fin was not consumed. - EXPECT_CALL(*session_, WriteData(kStreamId, _, _, _)).WillOnce( + EXPECT_CALL(*session_, WritevData(kStreamId, _, 1, _, _)).WillOnce( Return(QuicConsumedData(0, false))); stream_->WriteData(StringPiece(), true); - ASSERT_EQ(1, write_blocked_list_->NumObjects()); + ASSERT_EQ(1, write_blocked_list_->NumBlockedStreams()); } TEST_F(ReliableQuicStreamTest, WriteData) { Initialize(kShouldProcessData); - EXPECT_TRUE(write_blocked_list_->IsEmpty()); + EXPECT_FALSE(write_blocked_list_->HasWriteBlockedStreams()); connection_->options()->max_packet_length = 1 + QuicPacketCreator::StreamFramePacketOverhead( connection_->version(), PACKET_8BYTE_GUID, !kIncludeVersion, PACKET_6BYTE_SEQUENCE_NUMBER, NOT_IN_FEC_GROUP); - // TODO(rch): figure out how to get StrEq working here. - //EXPECT_CALL(*session_, WriteData(_, StrEq(kData1), _, _)).WillOnce( - EXPECT_CALL(*session_, WriteData(_, _, _, _)).WillOnce( + EXPECT_CALL(*session_, WritevData(_, _, 1, _, _)).WillOnce( Return(QuicConsumedData(kDataLen - 1, false))); // The return will be kDataLen, because the last byte gets buffered. EXPECT_EQ(kDataLen, stream_->WriteData(kData1, false).bytes_consumed); - EXPECT_FALSE(write_blocked_list_->IsEmpty()); + EXPECT_TRUE(write_blocked_list_->HasWriteBlockedStreams()); // Queue a bytes_consumed write. EXPECT_EQ(kDataLen, stream_->WriteData(kData2, false).bytes_consumed); // Make sure we get the tail of the first write followed by the bytes_consumed InSequence s; - //EXPECT_CALL(*session_, WriteData(_, StrEq(&kData1[kDataLen - 1]), _, _)). - EXPECT_CALL(*session_, WriteData(_, _, _, _)). + EXPECT_CALL(*session_, WritevData(_, _, 1, _, _)). WillOnce(Return(QuicConsumedData(1, false))); - //EXPECT_CALL(*session_, WriteData(_, StrEq(kData2), _, _)). - EXPECT_CALL(*session_, WriteData(_, _, _, _)). + EXPECT_CALL(*session_, WritevData(_, _, 1, _, _)). WillOnce(Return(QuicConsumedData(kDataLen - 2, false))); stream_->OnCanWrite(); - // And finally the end of the bytes_consumed - //EXPECT_CALL(*session_, WriteData(_, StrEq(&kData2[kDataLen - 2]), _, _)). - EXPECT_CALL(*session_, WriteData(_, _, _, _)). + // And finally the end of the bytes_consumed. + EXPECT_CALL(*session_, WritevData(_, _, 1, _, _)). WillOnce(Return(QuicConsumedData(2, true))); stream_->OnCanWrite(); } @@ -237,18 +230,22 @@ TEST_F(ReliableQuicStreamTest, ConnectionCloseAfterStreamClose) { TEST_F(ReliableQuicStreamTest, ProcessHeaders) { Initialize(kShouldProcessData); - string compressed_headers = compressor_->CompressHeaders(headers_); + string compressed_headers = + compressor_->CompressHeadersWithPriority(kHighestPriority, headers_); QuicStreamFrame frame(kStreamId, false, 0, compressed_headers); stream_->OnStreamFrame(frame); EXPECT_EQ(SpdyUtils::SerializeUncompressedHeaders(headers_), stream_->data()); + EXPECT_EQ(static_cast<QuicPriority>(kHighestPriority), + stream_->EffectivePriority()); } TEST_F(ReliableQuicStreamTest, ProcessHeadersWithInvalidHeaderId) { Initialize(kShouldProcessData); - string compressed_headers = compressor_->CompressHeaders(headers_); - compressed_headers.replace(0, 1, 1, '\xFF'); // Illegal header id. + string compressed_headers = + compressor_->CompressHeadersWithPriority(kHighestPriority, headers_); + compressed_headers.replace(4, 1, 1, '\xFF'); // Illegal header id. QuicStreamFrame frame(kStreamId, false, 0, compressed_headers); EXPECT_CALL(*connection_, SendConnectionClose(QUIC_INVALID_HEADER_ID)); @@ -258,7 +255,8 @@ TEST_F(ReliableQuicStreamTest, ProcessHeadersWithInvalidHeaderId) { TEST_F(ReliableQuicStreamTest, ProcessHeadersAndBody) { Initialize(kShouldProcessData); - string compressed_headers = compressor_->CompressHeaders(headers_); + string compressed_headers = + compressor_->CompressHeadersWithPriority(kHighestPriority, headers_); string body = "this is the body"; string data = compressed_headers + body; QuicStreamFrame frame(kStreamId, false, 0, data); @@ -271,7 +269,8 @@ TEST_F(ReliableQuicStreamTest, ProcessHeadersAndBody) { TEST_F(ReliableQuicStreamTest, ProcessHeadersAndBodyFragments) { Initialize(kShouldProcessData); - string compressed_headers = compressor_->CompressHeaders(headers_); + string compressed_headers = + compressor_->CompressHeadersWithPriority(kLowestPriority, headers_); string body = "this is the body"; string data = compressed_headers + body; @@ -303,12 +302,15 @@ TEST_F(ReliableQuicStreamTest, ProcessHeadersAndBodyFragments) { ASSERT_EQ(SpdyUtils::SerializeUncompressedHeaders(headers_) + body, stream_->data()) << "split_point: " << split_point; } + EXPECT_EQ(static_cast<QuicPriority>(kLowestPriority), + stream_->EffectivePriority()); } TEST_F(ReliableQuicStreamTest, ProcessHeadersAndBodyReadv) { Initialize(!kShouldProcessData); - string compressed_headers = compressor_->CompressHeaders(headers_); + string compressed_headers = + compressor_->CompressHeadersWithPriority(kHighestPriority, headers_); string body = "this is the body"; string data = compressed_headers + body; QuicStreamFrame frame(kStreamId, false, 0, data); @@ -337,7 +339,8 @@ TEST_F(ReliableQuicStreamTest, ProcessHeadersAndBodyReadv) { TEST_F(ReliableQuicStreamTest, ProcessHeadersAndBodyIncrementalReadv) { Initialize(!kShouldProcessData); - string compressed_headers = compressor_->CompressHeaders(headers_); + string compressed_headers = + compressor_->CompressHeadersWithPriority(kHighestPriority, headers_); string body = "this is the body"; string data = compressed_headers + body; QuicStreamFrame frame(kStreamId, false, 0, data); @@ -362,7 +365,8 @@ TEST_F(ReliableQuicStreamTest, ProcessHeadersAndBodyIncrementalReadv) { TEST_F(ReliableQuicStreamTest, ProcessHeadersUsingReadvWithMultipleIovecs) { Initialize(!kShouldProcessData); - string compressed_headers = compressor_->CompressHeaders(headers_); + string compressed_headers = + compressor_->CompressHeadersWithPriority(kHighestPriority, headers_); string body = "this is the body"; string data = compressed_headers + body; QuicStreamFrame frame(kStreamId, false, 0, data); @@ -391,13 +395,15 @@ TEST_F(ReliableQuicStreamTest, ProcessHeadersUsingReadvWithMultipleIovecs) { TEST_F(ReliableQuicStreamTest, ProcessCorruptHeadersEarly) { Initialize(kShouldProcessData); - string compressed_headers1 = compressor_->CompressHeaders(headers_); + string compressed_headers1 = + compressor_->CompressHeadersWithPriority(kHighestPriority, headers_); QuicStreamFrame frame1(stream_->id(), false, 0, compressed_headers1); string decompressed_headers1 = SpdyUtils::SerializeUncompressedHeaders(headers_); headers_["content-type"] = "text/plain"; - string compressed_headers2 = compressor_->CompressHeaders(headers_); + string compressed_headers2 = + compressor_->CompressHeadersWithPriority(kHighestPriority, headers_); // Corrupt the compressed data. compressed_headers2[compressed_headers2.length() - 1] ^= 0xA1; QuicStreamFrame frame2(stream2_->id(), false, 0, compressed_headers2); @@ -429,13 +435,15 @@ TEST_F(ReliableQuicStreamTest, ProcessCorruptHeadersEarly) { TEST_F(ReliableQuicStreamTest, ProcessPartialHeadersEarly) { Initialize(kShouldProcessData); - string compressed_headers1 = compressor_->CompressHeaders(headers_); + string compressed_headers1 = + compressor_->CompressHeadersWithPriority(kHighestPriority, headers_); QuicStreamFrame frame1(stream_->id(), false, 0, compressed_headers1); string decompressed_headers1 = SpdyUtils::SerializeUncompressedHeaders(headers_); headers_["content-type"] = "text/plain"; - string compressed_headers2 = compressor_->CompressHeaders(headers_); + string compressed_headers2 = + compressor_->CompressHeadersWithPriority(kHighestPriority, headers_); string partial_compressed_headers = compressed_headers2.substr(0, compressed_headers2.length() / 2); QuicStreamFrame frame2(stream2_->id(), false, 0, partial_compressed_headers); @@ -478,13 +486,15 @@ TEST_F(ReliableQuicStreamTest, ProcessPartialHeadersEarly) { TEST_F(ReliableQuicStreamTest, ProcessHeadersEarly) { Initialize(kShouldProcessData); - string compressed_headers1 = compressor_->CompressHeaders(headers_); + string compressed_headers1 = + compressor_->CompressHeadersWithPriority(kHighestPriority, headers_); QuicStreamFrame frame1(stream_->id(), false, 0, compressed_headers1); string decompressed_headers1 = SpdyUtils::SerializeUncompressedHeaders(headers_); headers_["content-type"] = "text/plain"; - string compressed_headers2 = compressor_->CompressHeaders(headers_); + string compressed_headers2 = + compressor_->CompressHeadersWithPriority(kHighestPriority, headers_); QuicStreamFrame frame2(stream2_->id(), false, 0, compressed_headers2); string decompressed_headers2 = SpdyUtils::SerializeUncompressedHeaders(headers_); @@ -512,7 +522,8 @@ TEST_F(ReliableQuicStreamTest, ProcessHeadersEarly) { TEST_F(ReliableQuicStreamTest, ProcessHeadersDelay) { Initialize(!kShouldProcessData); - string compressed_headers = compressor_->CompressHeaders(headers_); + string compressed_headers = + compressor_->CompressHeadersWithPriority(kHighestPriority, headers_); QuicStreamFrame frame1(stream_->id(), false, 0, compressed_headers); string decompressed_headers = SpdyUtils::SerializeUncompressedHeaders(headers_); diff --git a/chromium/net/quic/test_tools/mock_crypto_client_stream.cc b/chromium/net/quic/test_tools/mock_crypto_client_stream.cc index 79c33531714..bfeb06f0629 100644 --- a/chromium/net/quic/test_tools/mock_crypto_client_stream.cc +++ b/chromium/net/quic/test_tools/mock_crypto_client_stream.cc @@ -51,6 +51,16 @@ bool MockCryptoClientStream::CryptoConnect() { return true; } +void MockCryptoClientStream::SendOnCryptoHandshakeEvent( + QuicSession::CryptoHandshakeEvent event) { + encryption_established_ = true; + if (event == QuicSession::HANDSHAKE_CONFIRMED) { + handshake_confirmed_ = true; + SetConfigNegotiated(); + } + session()->OnCryptoHandshakeEvent(event); +} + void MockCryptoClientStream::SetConfigNegotiated() { ASSERT_FALSE(session()->config()->negotiated()); QuicTagVector cgst; diff --git a/chromium/net/quic/test_tools/mock_crypto_client_stream.h b/chromium/net/quic/test_tools/mock_crypto_client_stream.h index 2b73b8fd584..ada1b6883ba 100644 --- a/chromium/net/quic/test_tools/mock_crypto_client_stream.h +++ b/chromium/net/quic/test_tools/mock_crypto_client_stream.h @@ -8,6 +8,7 @@ #include <string> #include "net/quic/crypto/crypto_handshake.h" +#include "net/quic/crypto/crypto_protocol.h" #include "net/quic/quic_crypto_client_stream.h" #include "net/quic/quic_session.h" @@ -46,6 +47,10 @@ class MockCryptoClientStream : public QuicCryptoClientStream { // QuicCryptoClientStream implementation. virtual bool CryptoConnect() OVERRIDE; + // Invokes the sessions's CryptoHandshakeEvent method with the specified + // event. + void SendOnCryptoHandshakeEvent(QuicSession::CryptoHandshakeEvent event); + HandshakeMode handshake_mode_; private: diff --git a/chromium/net/quic/test_tools/mock_crypto_client_stream_factory.cc b/chromium/net/quic/test_tools/mock_crypto_client_stream_factory.cc index 7578790e136..e54fb41e2f5 100644 --- a/chromium/net/quic/test_tools/mock_crypto_client_stream_factory.cc +++ b/chromium/net/quic/test_tools/mock_crypto_client_stream_factory.cc @@ -12,7 +12,8 @@ using std::string; namespace net { MockCryptoClientStreamFactory::MockCryptoClientStreamFactory() - : handshake_mode_(MockCryptoClientStream::CONFIRM_HANDSHAKE) { + : handshake_mode_(MockCryptoClientStream::CONFIRM_HANDSHAKE), + last_stream_(NULL) { } QuicCryptoClientStream* @@ -20,8 +21,9 @@ MockCryptoClientStreamFactory::CreateQuicCryptoClientStream( const string& server_hostname, QuicSession* session, QuicCryptoClientConfig* crypto_config) { - return new MockCryptoClientStream(server_hostname, session, crypto_config, - handshake_mode_); + last_stream_ = new MockCryptoClientStream(server_hostname, session, + crypto_config, handshake_mode_); + return last_stream_; } } // namespace net diff --git a/chromium/net/quic/test_tools/mock_crypto_client_stream_factory.h b/chromium/net/quic/test_tools/mock_crypto_client_stream_factory.h index e3f2a4aba5c..9d056cbc72d 100644 --- a/chromium/net/quic/test_tools/mock_crypto_client_stream_factory.h +++ b/chromium/net/quic/test_tools/mock_crypto_client_stream_factory.h @@ -9,7 +9,6 @@ #include "net/quic/quic_crypto_client_stream.h" #include "net/quic/quic_crypto_client_stream_factory.h" -#include "net/quic/quic_session.h" #include "net/quic/test_tools/mock_crypto_client_stream.h" namespace net { @@ -30,8 +29,13 @@ class MockCryptoClientStreamFactory : public QuicCryptoClientStreamFactory { handshake_mode_ = handshake_mode; } + MockCryptoClientStream* last_stream() const { + return last_stream_; + } + private: MockCryptoClientStream::HandshakeMode handshake_mode_; + MockCryptoClientStream* last_stream_; }; } // namespace net diff --git a/chromium/net/quic/test_tools/quic_connection_peer.cc b/chromium/net/quic/test_tools/quic_connection_peer.cc index 610c505161b..5d715db4ef5 100644 --- a/chromium/net/quic/test_tools/quic_connection_peer.cc +++ b/chromium/net/quic/test_tools/quic_connection_peer.cc @@ -70,17 +70,15 @@ QuicTime::Delta QuicConnectionPeer::GetNetworkTimeout( bool QuicConnectionPeer::IsSavedForRetransmission( QuicConnection* connection, QuicPacketSequenceNumber sequence_number) { - return ContainsKey(connection->retransmission_map_, sequence_number); + return connection->sent_packet_manager_.IsUnacked(sequence_number); } // static size_t QuicConnectionPeer::GetRetransmissionCount( QuicConnection* connection, QuicPacketSequenceNumber sequence_number) { - QuicConnection::RetransmissionMap::iterator it = - connection->retransmission_map_.find(sequence_number); - DCHECK(connection->retransmission_map_.end() != it); - return it->second.number_retransmissions; + return connection->sent_packet_manager_.GetRetransmissionCount( + sequence_number); } // static @@ -127,6 +125,12 @@ void QuicConnectionPeer::SetSelfAddress(QuicConnection* connection, } // static +void QuicConnectionPeer::SetPeerAddress(QuicConnection* connection, + const IPEndPoint& peer_address) { + connection->peer_address_ = peer_address; +} + +// static void QuicConnectionPeer::SwapCrypters(QuicConnection* connection, QuicFramer* framer) { framer->SwapCryptersForTest(&connection->framer_); diff --git a/chromium/net/quic/test_tools/quic_connection_peer.h b/chromium/net/quic/test_tools/quic_connection_peer.h index 4438353ddaa..13ff14dc420 100644 --- a/chromium/net/quic/test_tools/quic_connection_peer.h +++ b/chromium/net/quic/test_tools/quic_connection_peer.h @@ -7,8 +7,8 @@ #include "base/basictypes.h" #include "net/base/ip_endpoint.h" +#include "net/quic/quic_connection_stats.h" #include "net/quic/quic_protocol.h" -#include "net/quic/quic_stats.h" namespace net { @@ -78,6 +78,9 @@ class QuicConnectionPeer { static void SetSelfAddress(QuicConnection* connection, const IPEndPoint& self_address); + static void SetPeerAddress(QuicConnection* connection, + const IPEndPoint& peer_address); + static void SwapCrypters(QuicConnection* connection, QuicFramer* framer); static void SetMaxPacketsPerRetransmissionAlarm(QuicConnection* connection, diff --git a/chromium/net/quic/test_tools/quic_framer_peer.cc b/chromium/net/quic/test_tools/quic_framer_peer.cc index 5ec52dc7512..e8d43cdc10e 100644 --- a/chromium/net/quic/test_tools/quic_framer_peer.cc +++ b/chromium/net/quic/test_tools/quic_framer_peer.cc @@ -34,9 +34,5 @@ void QuicFramerPeer::SetIsServer(QuicFramer* framer, bool is_server) { framer->is_server_ = is_server; } -void QuicFramerPeer::SetVersion(QuicFramer* framer, QuicVersion version) { - framer->quic_version_ = version; -} - } // namespace test } // namespace net diff --git a/chromium/net/quic/test_tools/quic_framer_peer.h b/chromium/net/quic/test_tools/quic_framer_peer.h index 0508f5c7a7e..acb45ecb173 100644 --- a/chromium/net/quic/test_tools/quic_framer_peer.h +++ b/chromium/net/quic/test_tools/quic_framer_peer.h @@ -24,7 +24,6 @@ class QuicFramerPeer { QuicFramer* framer, QuicPacketSequenceNumber packet_sequence_number); static void SetIsServer(QuicFramer* framer, bool is_server); - static void SetVersion(QuicFramer* framer, QuicVersion version); private: DISALLOW_COPY_AND_ASSIGN(QuicFramerPeer); diff --git a/chromium/net/quic/test_tools/quic_packet_creator_peer.cc b/chromium/net/quic/test_tools/quic_packet_creator_peer.cc index 4451f02be84..1acdf244f2f 100644 --- a/chromium/net/quic/test_tools/quic_packet_creator_peer.cc +++ b/chromium/net/quic/test_tools/quic_packet_creator_peer.cc @@ -21,6 +21,19 @@ void QuicPacketCreatorPeer::SetSendVersionInPacket( } // static +void QuicPacketCreatorPeer::SetSequenceNumberLength( + QuicPacketCreator* creator, + QuicSequenceNumberLength sequence_number_length) { + creator->sequence_number_length_ = sequence_number_length; +} + +// static +QuicSequenceNumberLength QuicPacketCreatorPeer::GetSequenceNumberLength( + QuicPacketCreator* creator) { + return creator->sequence_number_length_; +} + +// static void QuicPacketCreatorPeer::SetIsServer(QuicPacketCreator* creator, bool is_server) { creator->is_server_ = is_server; diff --git a/chromium/net/quic/test_tools/quic_packet_creator_peer.h b/chromium/net/quic/test_tools/quic_packet_creator_peer.h index 816afa96189..12ae676664b 100644 --- a/chromium/net/quic/test_tools/quic_packet_creator_peer.h +++ b/chromium/net/quic/test_tools/quic_packet_creator_peer.h @@ -18,6 +18,11 @@ class QuicPacketCreatorPeer { static void SetSendVersionInPacket(QuicPacketCreator* creator, bool send_version_in_packet); + static void SetSequenceNumberLength( + QuicPacketCreator* creator, + QuicSequenceNumberLength sequence_number_length); + static QuicSequenceNumberLength GetSequenceNumberLength( + QuicPacketCreator* creator); static void SetIsServer(QuicPacketCreator* creator, bool is_server); diff --git a/chromium/net/quic/test_tools/quic_session_peer.cc b/chromium/net/quic/test_tools/quic_session_peer.cc index 66caa15a4b6..c25b42fb216 100644 --- a/chromium/net/quic/test_tools/quic_session_peer.cc +++ b/chromium/net/quic/test_tools/quic_session_peer.cc @@ -22,13 +22,7 @@ void QuicSessionPeer::SetMaxOpenStreams(QuicSession* session, } // static -ReliableQuicStream* QuicSessionPeer::CreateIncomingReliableStream( - QuicSession* session, QuicStreamId id) { - return session->CreateIncomingReliableStream(id); -} - -// static -BlockedList<QuicStreamId>* QuicSessionPeer::GetWriteblockedStreams( +WriteBlockedList<QuicStreamId>* QuicSessionPeer::GetWriteblockedStreams( QuicSession* session) { return &session->write_blocked_streams_; } diff --git a/chromium/net/quic/test_tools/quic_session_peer.h b/chromium/net/quic/test_tools/quic_session_peer.h index 6f9a8f394e2..fb4529cc0ab 100644 --- a/chromium/net/quic/test_tools/quic_session_peer.h +++ b/chromium/net/quic/test_tools/quic_session_peer.h @@ -5,8 +5,8 @@ #ifndef NET_QUIC_TEST_TOOLS_QUIC_SESSION_PEER_H_ #define NET_QUIC_TEST_TOOLS_QUIC_SESSION_PEER_H_ -#include "net/quic/blocked_list.h" #include "net/quic/quic_protocol.h" +#include "net/spdy/write_blocked_list.h" namespace net { @@ -19,9 +19,7 @@ class QuicSessionPeer { public: static void SetNextStreamId(QuicSession* session, QuicStreamId id); static void SetMaxOpenStreams(QuicSession* session, uint32 max_streams); - static ReliableQuicStream* CreateIncomingReliableStream(QuicSession* session, - QuicStreamId id); - static BlockedList<QuicStreamId>* GetWriteblockedStreams( + static WriteBlockedList<QuicStreamId>* GetWriteblockedStreams( QuicSession* session); private: diff --git a/chromium/net/quic/test_tools/quic_test_utils.cc b/chromium/net/quic/test_tools/quic_test_utils.cc index 2562b07d599..8ba46b8f4ac 100644 --- a/chromium/net/quic/test_tools/quic_test_utils.cc +++ b/chromium/net/quic/test_tools/quic_test_utils.cc @@ -243,8 +243,9 @@ bool PacketSavingConnection::SendOrQueuePacket( EncryptionLevel level, QuicPacketSequenceNumber sequence_number, QuicPacket* packet, - QuicPacketEntropyHash entropy_hash, - HasRetransmittableData retransmittable) { + QuicPacketEntropyHash /* entropy_hash */, + HasRetransmittableData /* retransmittable */, + Force /* forced */) { packets_.push_back(packet); QuicEncryptedPacket* encrypted = framer_.EncryptPacket(level, sequence_number, *packet); @@ -254,7 +255,7 @@ bool PacketSavingConnection::SendOrQueuePacket( MockSession::MockSession(QuicConnection* connection, bool is_server) : QuicSession(connection, DefaultQuicConfig(), is_server) { - ON_CALL(*this, WriteData(_, _, _, _)) + ON_CALL(*this, WritevData(_, _, _, _, _)) .WillByDefault(testing::Return(QuicConsumedData(0, false))); } @@ -284,6 +285,12 @@ MockSendAlgorithm::MockSendAlgorithm() { MockSendAlgorithm::~MockSendAlgorithm() { } +MockAckNotifierDelegate::MockAckNotifierDelegate() { +} + +MockAckNotifierDelegate::~MockAckNotifierDelegate() { +} + namespace { string HexDumpWithMarks(const char* data, int length, @@ -356,16 +363,6 @@ void CompareCharArraysWithHexError( << HexDumpWithMarks(actual, actual_len, marks.get(), max_len); } -void CompareQuicDataWithHexError( - const string& description, - QuicData* actual, - QuicData* expected) { - CompareCharArraysWithHexError( - description, - actual->data(), actual->length(), - expected->data(), expected->length()); -} - static QuicPacket* ConstructPacketFromHandshakeMessage( QuicGuid guid, const CryptoHandshakeMessage& message, @@ -399,20 +396,22 @@ QuicPacket* ConstructHandshakePacket(QuicGuid guid, QuicTag tag) { return ConstructPacketFromHandshakeMessage(guid, message, false); } -size_t GetPacketLengthForOneStream(QuicVersion version, - bool include_version, - InFecGroup is_in_fec_group, - size_t* payload_length) { +size_t GetPacketLengthForOneStream( + QuicVersion version, + bool include_version, + QuicSequenceNumberLength sequence_number_length, + InFecGroup is_in_fec_group, + size_t* payload_length) { *payload_length = 1; const size_t stream_length = NullEncrypter().GetCiphertextSize(*payload_length) + QuicPacketCreator::StreamFramePacketOverhead( version, PACKET_8BYTE_GUID, include_version, - PACKET_6BYTE_SEQUENCE_NUMBER, is_in_fec_group); + sequence_number_length, is_in_fec_group); const size_t ack_length = NullEncrypter().GetCiphertextSize( QuicFramer::GetMinAckFrameSize()) + GetPacketHeaderSize(PACKET_8BYTE_GUID, include_version, - PACKET_6BYTE_SEQUENCE_NUMBER, is_in_fec_group); + sequence_number_length, is_in_fec_group); if (stream_length < ack_length) { *payload_length = 1 + ack_length - stream_length; } @@ -420,7 +419,7 @@ size_t GetPacketLengthForOneStream(QuicVersion version, return NullEncrypter().GetCiphertextSize(*payload_length) + QuicPacketCreator::StreamFramePacketOverhead( version, PACKET_8BYTE_GUID, include_version, - PACKET_6BYTE_SEQUENCE_NUMBER, is_in_fec_group); + sequence_number_length, is_in_fec_group); } // Size in bytes of the stream frame fields for an arbitrary StreamID and diff --git a/chromium/net/quic/test_tools/quic_test_utils.h b/chromium/net/quic/test_tools/quic_test_utils.h index 65fba73d3b9..64a9d30ada2 100644 --- a/chromium/net/quic/test_tools/quic_test_utils.h +++ b/chromium/net/quic/test_tools/quic_test_utils.h @@ -31,24 +31,20 @@ void CompareCharArraysWithHexError(const std::string& description, const char* expected, const int expected_len); -void CompareQuicDataWithHexError(const std::string& description, - QuicData* actual, - QuicData* expected); - // Returns the length of a QuicPacket that is capable of holding either a // stream frame or a minimal ack frame. Sets |*payload_length| to the number // of bytes of stream data that will fit in such a packet. -size_t GetPacketLengthForOneStream(QuicVersion version, - bool include_version, - InFecGroup is_in_fec_group, - size_t* payload_length); +size_t GetPacketLengthForOneStream( + QuicVersion version, + bool include_version, + QuicSequenceNumberLength sequence_number_length, + InFecGroup is_in_fec_group, + size_t* payload_length); // Size in bytes of the stream frame fields for an arbitrary StreamID and // offset and the last frame in a packet. size_t GetMinStreamFrameSize(QuicVersion version); -string SerializeUncompressedHeaders(const SpdyHeaderBlock& headers); - // Returns QuicConfig set to default values. QuicConfig DefaultQuicConfig(); @@ -179,15 +175,14 @@ class MockConnectionVisitor : public QuicConnectionVisitorInterface { MockConnectionVisitor(); virtual ~MockConnectionVisitor(); - MOCK_METHOD4(OnPacket, bool(const IPEndPoint& self_address, - const IPEndPoint& peer_address, - const QuicPacketHeader& header, - const std::vector<QuicStreamFrame>& frame)); + MOCK_METHOD1(OnStreamFrames, bool(const std::vector<QuicStreamFrame>& frame)); MOCK_METHOD1(OnRstStream, void(const QuicRstStreamFrame& frame)); MOCK_METHOD1(OnGoAway, void(const QuicGoAwayFrame& frame)); MOCK_METHOD2(ConnectionClose, void(QuicErrorCode error, bool from_peer)); - MOCK_METHOD1(OnAck, void(const SequenceNumberSet& acked_packets)); MOCK_METHOD0(OnCanWrite, bool()); + MOCK_CONST_METHOD0(HasPendingHandshake, bool()); + MOCK_METHOD1(OnSuccessfulVersionNegotiation, + void(const QuicVersion& version)); private: DISALLOW_COPY_AND_ASSIGN(MockConnectionVisitor); @@ -205,7 +200,7 @@ class MockHelper : public QuicConnectionHelperInterface { MOCK_METHOD2(WritePacketToWire, int(const QuicEncryptedPacket& packet, int* error)); MOCK_METHOD0(IsWriteBlockedDataBuffered, bool()); - MOCK_METHOD1(IsWriteBlocked, bool(int)); + MOCK_METHOD1(IsWriteBlocked, bool(int stream_id)); virtual QuicAlarm* CreateAlarm(QuicAlarm::Delegate* delegate); private: @@ -239,6 +234,7 @@ class MockConnection : public QuicConnection { QuicStreamId last_good_stream_id, const string& reason)); MOCK_METHOD0(OnCanWrite, bool()); + MOCK_CONST_METHOD0(HasPendingHandshake, bool()); void ProcessUdpPacketInternal(const IPEndPoint& self_address, const IPEndPoint& peer_address, @@ -266,7 +262,8 @@ class PacketSavingConnection : public MockConnection { QuicPacketSequenceNumber sequence_number, QuicPacket* packet, QuicPacketEntropyHash entropy_hash, - HasRetransmittableData has_retransmittable_data) OVERRIDE; + HasRetransmittableData has_retransmittable_data, + Force forced) OVERRIDE; std::vector<QuicPacket*> packets_; std::vector<QuicEncryptedPacket*> encrypted_packets_; @@ -289,10 +286,11 @@ class MockSession : public QuicSession { ReliableQuicStream*(QuicStreamId id)); MOCK_METHOD0(GetCryptoStream, QuicCryptoStream*()); MOCK_METHOD0(CreateOutgoingReliableStream, ReliableQuicStream*()); - MOCK_METHOD4(WriteData, QuicConsumedData(QuicStreamId id, - base::StringPiece data, - QuicStreamOffset offset, - bool fin)); + MOCK_METHOD5(WritevData, QuicConsumedData(QuicStreamId id, + const struct iovec* iov, + int count, + QuicStreamOffset offset, + bool fin)); MOCK_METHOD0(IsHandshakeComplete, bool()); private: @@ -331,8 +329,9 @@ class MockSendAlgorithm : public SendAlgorithmInterface { MOCK_METHOD3(OnIncomingAck, void(QuicPacketSequenceNumber, QuicByteCount, QuicTime::Delta)); MOCK_METHOD1(OnIncomingLoss, void(QuicTime)); - MOCK_METHOD4(SentPacket, void(QuicTime sent_time, QuicPacketSequenceNumber, - QuicByteCount, Retransmission)); + MOCK_METHOD5(SentPacket, + bool(QuicTime sent_time, QuicPacketSequenceNumber, QuicByteCount, + Retransmission, HasRetransmittableData)); MOCK_METHOD2(AbandoningPacket, void(QuicPacketSequenceNumber sequence_number, QuicByteCount abandoned_bytes)); MOCK_METHOD4(TimeUntilSend, QuicTime::Delta(QuicTime now, Retransmission, @@ -370,6 +369,14 @@ class TestDecompressorVisitor : public QuicSpdyDecompressor::Visitor { bool error_; }; +class MockAckNotifierDelegate : public QuicAckNotifier::DelegateInterface { + public: + MockAckNotifierDelegate(); + virtual ~MockAckNotifierDelegate(); + + MOCK_METHOD0(OnAckNotification, void()); +}; + } // namespace test } // namespace net diff --git a/chromium/net/server/http_connection.cc b/chromium/net/server/http_connection.cc index d964cb0738b..d433012cd65 100644 --- a/chromium/net/server/http_connection.cc +++ b/chromium/net/server/http_connection.cc @@ -29,21 +29,17 @@ void HttpConnection::Send(const HttpServerResponseInfo& response) { Send(response.Serialize()); } -HttpConnection::HttpConnection(HttpServer* server, StreamListenSocket* sock) +HttpConnection::HttpConnection(HttpServer* server, + scoped_ptr<StreamListenSocket> sock) : server_(server), - socket_(sock) { + socket_(sock.Pass()) { id_ = last_id_++; } HttpConnection::~HttpConnection() { - DetachSocket(); server_->delegate_->OnClose(id_); } -void HttpConnection::DetachSocket() { - socket_ = NULL; -} - void HttpConnection::Shift(int num_bytes) { recv_data_ = recv_data_.substr(num_bytes); } diff --git a/chromium/net/server/http_connection.h b/chromium/net/server/http_connection.h index b0e37663d4f..17faa46eb6b 100644 --- a/chromium/net/server/http_connection.h +++ b/chromium/net/server/http_connection.h @@ -8,7 +8,6 @@ #include <string> #include "base/basictypes.h" -#include "base/memory/ref_counted.h" #include "base/memory/scoped_ptr.h" #include "net/http/http_status_code.h" @@ -36,12 +35,10 @@ class HttpConnection { friend class HttpServer; static int last_id_; - HttpConnection(HttpServer* server, StreamListenSocket* sock); - - void DetachSocket(); + HttpConnection(HttpServer* server, scoped_ptr<StreamListenSocket> sock); HttpServer* server_; - scoped_refptr<StreamListenSocket> socket_; + scoped_ptr<StreamListenSocket> socket_; scoped_ptr<WebSocket> web_socket_; std::string recv_data_; int id_; diff --git a/chromium/net/server/http_server.cc b/chromium/net/server/http_server.cc index 373025c4aa6..a51feb84401 100644 --- a/chromium/net/server/http_server.cc +++ b/chromium/net/server/http_server.cc @@ -95,10 +95,11 @@ int HttpServer::GetLocalAddress(IPEndPoint* address) { } void HttpServer::DidAccept(StreamListenSocket* server, - StreamListenSocket* socket) { - HttpConnection* connection = new HttpConnection(this, socket); + scoped_ptr<StreamListenSocket> socket) { + HttpConnection* connection = new HttpConnection(this, socket.Pass()); id_to_connection_[connection->id()] = connection; - socket_to_connection_[socket] = connection; + // TODO(szym): Fix socket access. Make HttpConnection the Delegate. + socket_to_connection_[connection->socket_.get()] = connection; } void HttpServer::DidRead(StreamListenSocket* socket, @@ -180,7 +181,6 @@ void HttpServer::DidClose(StreamListenSocket* socket) { HttpServer::~HttpServer() { STLDeleteContainerPairSecondPointers( id_to_connection_.begin(), id_to_connection_.end()); - server_ = NULL; } // diff --git a/chromium/net/server/http_server.h b/chromium/net/server/http_server.h index f4345752e2d..51bec956889 100644 --- a/chromium/net/server/http_server.h +++ b/chromium/net/server/http_server.h @@ -9,7 +9,7 @@ #include <map> #include "base/basictypes.h" -#include "base/memory/ref_counted.h" +#include "base/memory/scoped_ptr.h" #include "net/http/http_status_code.h" #include "net/socket/stream_listen_socket.h" @@ -65,7 +65,7 @@ class HttpServer : public StreamListenSocket::Delegate, // ListenSocketDelegate virtual void DidAccept(StreamListenSocket* server, - StreamListenSocket* socket) OVERRIDE; + scoped_ptr<StreamListenSocket> socket) OVERRIDE; virtual void DidRead(StreamListenSocket* socket, const char* data, int len) OVERRIDE; @@ -89,7 +89,7 @@ class HttpServer : public StreamListenSocket::Delegate, HttpConnection* FindConnection(StreamListenSocket* socket); HttpServer::Delegate* delegate_; - scoped_refptr<StreamListenSocket> server_; + scoped_ptr<StreamListenSocket> server_; typedef std::map<int, HttpConnection*> IdToConnectionMap; IdToConnectionMap id_to_connection_; typedef std::map<StreamListenSocket*, HttpConnection*> SocketToConnectionMap; diff --git a/chromium/net/server/http_server_unittest.cc b/chromium/net/server/http_server_unittest.cc index 48a2ce7571f..ede5066c3cd 100644 --- a/chromium/net/server/http_server_unittest.cc +++ b/chromium/net/server/http_server_unittest.cc @@ -276,9 +276,9 @@ class MockStreamListenSocket : public StreamListenSocket { } // namespace TEST_F(HttpServerTest, RequestWithBodySplitAcrossPackets) { - scoped_refptr<StreamListenSocket> socket( - new MockStreamListenSocket(server_.get())); - server_->DidAccept(NULL, socket.get()); + StreamListenSocket* socket = + new MockStreamListenSocket(server_.get()); + server_->DidAccept(NULL, make_scoped_ptr(socket)); std::string body("body"); std::string request = base::StringPrintf( "GET /test HTTP/1.1\r\n" @@ -286,9 +286,9 @@ TEST_F(HttpServerTest, RequestWithBodySplitAcrossPackets) { "Content-Length: %" PRIuS "\r\n\r\n%s", body.length(), body.c_str()); - server_->DidRead(socket.get(), request.c_str(), request.length() - 2); + server_->DidRead(socket, request.c_str(), request.length() - 2); ASSERT_EQ(0u, requests_.size()); - server_->DidRead(socket.get(), request.c_str() + request.length() - 2, 2); + server_->DidRead(socket, request.c_str() + request.length() - 2, 2); ASSERT_EQ(1u, requests_.size()); ASSERT_EQ(body, requests_[0].data); } diff --git a/chromium/net/socket/buffered_write_stream_socket.cc b/chromium/net/socket/buffered_write_stream_socket.cc index 36b9df715fd..cf13c5e439a 100644 --- a/chromium/net/socket/buffered_write_stream_socket.cc +++ b/chromium/net/socket/buffered_write_stream_socket.cc @@ -23,8 +23,8 @@ void AppendBuffer(GrowableIOBuffer* dst, IOBuffer* src, int src_len) { } // anonymous namespace BufferedWriteStreamSocket::BufferedWriteStreamSocket( - StreamSocket* socket_to_wrap) - : wrapped_socket_(socket_to_wrap), + scoped_ptr<StreamSocket> socket_to_wrap) + : wrapped_socket_(socket_to_wrap.Pass()), io_buffer_(new GrowableIOBuffer()), backup_buffer_(new GrowableIOBuffer()), weak_factory_(this), diff --git a/chromium/net/socket/buffered_write_stream_socket.h b/chromium/net/socket/buffered_write_stream_socket.h index fcb33a81910..aad5736d0b0 100644 --- a/chromium/net/socket/buffered_write_stream_socket.h +++ b/chromium/net/socket/buffered_write_stream_socket.h @@ -5,6 +5,8 @@ #ifndef NET_SOCKET_BUFFERED_WRITE_STREAM_SOCKET_H_ #define NET_SOCKET_BUFFERED_WRITE_STREAM_SOCKET_H_ +#include "base/basictypes.h" +#include "base/memory/scoped_ptr.h" #include "base/memory/weak_ptr.h" #include "net/base/net_log.h" #include "net/socket/stream_socket.h" @@ -33,7 +35,7 @@ class IPEndPoint; // There are no bounds on the local buffer size. Use carefully. class NET_EXPORT_PRIVATE BufferedWriteStreamSocket : public StreamSocket { public: - BufferedWriteStreamSocket(StreamSocket* socket_to_wrap); + explicit BufferedWriteStreamSocket(scoped_ptr<StreamSocket> socket_to_wrap); virtual ~BufferedWriteStreamSocket(); // Socket interface @@ -71,6 +73,8 @@ class NET_EXPORT_PRIVATE BufferedWriteStreamSocket : public StreamSocket { bool callback_pending_; bool wrapped_write_in_progress_; int error_; + + DISALLOW_COPY_AND_ASSIGN(BufferedWriteStreamSocket); }; } // namespace net diff --git a/chromium/net/socket/buffered_write_stream_socket_unittest.cc b/chromium/net/socket/buffered_write_stream_socket_unittest.cc index e579a7f51d2..485295f33f6 100644 --- a/chromium/net/socket/buffered_write_stream_socket_unittest.cc +++ b/chromium/net/socket/buffered_write_stream_socket_unittest.cc @@ -30,10 +30,11 @@ class BufferedWriteStreamSocketTest : public testing::Test { if (writes_count) { data_->StopAfter(writes_count); } - DeterministicMockTCPClientSocket* wrapped_socket = - new DeterministicMockTCPClientSocket(net_log_.net_log(), data_.get()); + scoped_ptr<DeterministicMockTCPClientSocket> wrapped_socket( + new DeterministicMockTCPClientSocket(net_log_.net_log(), data_.get())); data_->set_delegate(wrapped_socket->AsWeakPtr()); - socket_.reset(new BufferedWriteStreamSocket(wrapped_socket)); + socket_.reset(new BufferedWriteStreamSocket( + wrapped_socket.PassAs<StreamSocket>())); socket_->Connect(callback_.callback()); } diff --git a/chromium/net/socket/client_socket_factory.cc b/chromium/net/socket/client_socket_factory.cc index 022988aa6a9..a86688e3333 100644 --- a/chromium/net/socket/client_socket_factory.cc +++ b/chromium/net/socket/client_socket_factory.cc @@ -67,23 +67,25 @@ class DefaultClientSocketFactory : public ClientSocketFactory, ClearSSLSessionCache(); } - virtual DatagramClientSocket* CreateDatagramClientSocket( + virtual scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket( DatagramSocket::BindType bind_type, const RandIntCallback& rand_int_cb, NetLog* net_log, const NetLog::Source& source) OVERRIDE { - return new UDPClientSocket(bind_type, rand_int_cb, net_log, source); + return scoped_ptr<DatagramClientSocket>( + new UDPClientSocket(bind_type, rand_int_cb, net_log, source)); } - virtual StreamSocket* CreateTransportClientSocket( + virtual scoped_ptr<StreamSocket> CreateTransportClientSocket( const AddressList& addresses, NetLog* net_log, const NetLog::Source& source) OVERRIDE { - return new TCPClientSocket(addresses, net_log, source); + return scoped_ptr<StreamSocket>( + new TCPClientSocket(addresses, net_log, source)); } - virtual SSLClientSocket* CreateSSLClientSocket( - ClientSocketHandle* transport_socket, + virtual scoped_ptr<SSLClientSocket> CreateSSLClientSocket( + scoped_ptr<ClientSocketHandle> transport_socket, const HostPortPair& host_and_port, const SSLConfig& ssl_config, const SSLClientSocketContext& context) OVERRIDE { @@ -102,17 +104,19 @@ class DefaultClientSocketFactory : public ClientSocketFactory, nss_task_runner = base::ThreadTaskRunnerHandle::Get(); #if defined(USE_OPENSSL) - return new SSLClientSocketOpenSSL(transport_socket, host_and_port, - ssl_config, context); + return scoped_ptr<SSLClientSocket>( + new SSLClientSocketOpenSSL(transport_socket.Pass(), host_and_port, + ssl_config, context)); #elif defined(USE_NSS) || defined(OS_MACOSX) || defined(OS_WIN) - return new SSLClientSocketNSS(nss_task_runner.get(), - transport_socket, - host_and_port, - ssl_config, - context); + return scoped_ptr<SSLClientSocket>( + new SSLClientSocketNSS(nss_task_runner.get(), + transport_socket.Pass(), + host_and_port, + ssl_config, + context)); #else NOTIMPLEMENTED(); - return NULL; + return scoped_ptr<SSLClientSocket>(); #endif } @@ -130,18 +134,6 @@ static base::LazyInstance<DefaultClientSocketFactory>::Leaky } // namespace -// Deprecated function (http://crbug.com/37810) that takes a StreamSocket. -SSLClientSocket* ClientSocketFactory::CreateSSLClientSocket( - StreamSocket* transport_socket, - const HostPortPair& host_and_port, - const SSLConfig& ssl_config, - const SSLClientSocketContext& context) { - ClientSocketHandle* socket_handle = new ClientSocketHandle(); - socket_handle->set_socket(transport_socket); - return CreateSSLClientSocket(socket_handle, host_and_port, ssl_config, - context); -} - // static ClientSocketFactory* ClientSocketFactory::GetDefaultFactory() { return g_default_client_socket_factory.Pointer(); diff --git a/chromium/net/socket/client_socket_factory.h b/chromium/net/socket/client_socket_factory.h index 65022f29234..6cb5949f0b3 100644 --- a/chromium/net/socket/client_socket_factory.h +++ b/chromium/net/socket/client_socket_factory.h @@ -8,6 +8,7 @@ #include <string> #include "base/basictypes.h" +#include "base/memory/scoped_ptr.h" #include "net/base/net_export.h" #include "net/base/net_log.h" #include "net/base/rand_callback.h" @@ -32,13 +33,13 @@ class NET_EXPORT ClientSocketFactory { // |source| is the NetLog::Source for the entity trying to create the socket, // if it has one. - virtual DatagramClientSocket* CreateDatagramClientSocket( + virtual scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket( DatagramSocket::BindType bind_type, const RandIntCallback& rand_int_cb, NetLog* net_log, const NetLog::Source& source) = 0; - virtual StreamSocket* CreateTransportClientSocket( + virtual scoped_ptr<StreamSocket> CreateTransportClientSocket( const AddressList& addresses, NetLog* net_log, const NetLog::Source& source) = 0; @@ -46,19 +47,12 @@ class NET_EXPORT ClientSocketFactory { // It is allowed to pass in a |transport_socket| that is not obtained from a // socket pool. The caller could create a ClientSocketHandle directly and call // set_socket() on it to set a valid StreamSocket instance. - virtual SSLClientSocket* CreateSSLClientSocket( - ClientSocketHandle* transport_socket, + virtual scoped_ptr<SSLClientSocket> CreateSSLClientSocket( + scoped_ptr<ClientSocketHandle> transport_socket, const HostPortPair& host_and_port, const SSLConfig& ssl_config, const SSLClientSocketContext& context) = 0; - // Deprecated function (http://crbug.com/37810) that takes a StreamSocket. - virtual SSLClientSocket* CreateSSLClientSocket( - StreamSocket* transport_socket, - const HostPortPair& host_and_port, - const SSLConfig& ssl_config, - const SSLClientSocketContext& context); - // Clears cache used for SSL session resumption. virtual void ClearSSLSessionCache() = 0; diff --git a/chromium/net/socket/client_socket_handle.cc b/chromium/net/socket/client_socket_handle.cc index 3894fa7aa0e..e42e9fcada3 100644 --- a/chromium/net/socket/client_socket_handle.cc +++ b/chromium/net/socket/client_socket_handle.cc @@ -18,7 +18,7 @@ namespace net { ClientSocketHandle::ClientSocketHandle() : is_initialized_(false), pool_(NULL), - layered_pool_(NULL), + higher_pool_(NULL), is_reused_(false), callback_(base::Bind(&ClientSocketHandle::OnIOComplete, base::Unretained(this))), @@ -34,29 +34,34 @@ void ClientSocketHandle::Reset() { } void ClientSocketHandle::ResetInternal(bool cancel) { - if (group_name_.empty()) // Was Init called? - return; - if (is_initialized()) { - // Because of http://crbug.com/37810 we may not have a pool, but have - // just a raw socket. - socket_->NetLog().EndEvent(NetLog::TYPE_SOCKET_IN_USE); - if (pool_) - // If we've still got a socket, release it back to the ClientSocketPool so - // it can be deleted or reused. - pool_->ReleaseSocket(group_name_, release_socket(), pool_id_); - } else if (cancel) { - // If we did not get initialized yet, we've got a socket request pending. - // Cancel it. - pool_->CancelRequest(group_name_, this); + // Was Init called? + if (!group_name_.empty()) { + // If so, we must have a pool. + CHECK(pool_); + if (is_initialized()) { + if (socket_) { + socket_->NetLog().EndEvent(NetLog::TYPE_SOCKET_IN_USE); + // Release the socket back to the ClientSocketPool so it can be + // deleted or reused. + pool_->ReleaseSocket(group_name_, socket_.Pass(), pool_id_); + } else { + // If the handle has been initialized, we should still have a + // socket. + NOTREACHED(); + } + } else if (cancel) { + // If we did not get initialized yet and we have a socket + // request pending, cancel it. + pool_->CancelRequest(group_name_, this); + } } is_initialized_ = false; + socket_.reset(); group_name_.clear(); is_reused_ = false; user_callback_.Reset(); - if (layered_pool_) { - pool_->RemoveLayeredPool(layered_pool_); - layered_pool_ = NULL; - } + if (higher_pool_) + RemoveHigherLayeredPool(higher_pool_); pool_ = NULL; idle_time_ = base::TimeDelta(); init_time_ = base::TimeTicks(); @@ -82,24 +87,30 @@ LoadState ClientSocketHandle::GetLoadState() const { } bool ClientSocketHandle::IsPoolStalled() const { + if (!pool_) + return false; return pool_->IsStalled(); } -void ClientSocketHandle::AddLayeredPool(LayeredPool* layered_pool) { - CHECK(layered_pool); - CHECK(!layered_pool_); +void ClientSocketHandle::AddHigherLayeredPool(HigherLayeredPool* higher_pool) { + CHECK(higher_pool); + CHECK(!higher_pool_); + // TODO(mmenke): |pool_| should only be NULL in tests. Maybe stop doing that + // so this be be made into a DCHECK, and the same can be done in + // RemoveHigherLayeredPool? if (pool_) { - pool_->AddLayeredPool(layered_pool); - layered_pool_ = layered_pool; + pool_->AddHigherLayeredPool(higher_pool); + higher_pool_ = higher_pool; } } -void ClientSocketHandle::RemoveLayeredPool(LayeredPool* layered_pool) { - CHECK(layered_pool); - CHECK(layered_pool_); +void ClientSocketHandle::RemoveHigherLayeredPool( + HigherLayeredPool* higher_pool) { + CHECK(higher_pool_); + CHECK_EQ(higher_pool_, higher_pool); if (pool_) { - pool_->RemoveLayeredPool(layered_pool); - layered_pool_ = NULL; + pool_->RemoveHigherLayeredPool(higher_pool); + higher_pool_ = NULL; } } @@ -121,6 +132,10 @@ bool ClientSocketHandle::GetLoadTimingInfo( return true; } +void ClientSocketHandle::SetSocket(scoped_ptr<StreamSocket> s) { + socket_ = s.Pass(); +} + void ClientSocketHandle::OnIOComplete(int result) { CompletionCallback callback = user_callback_; user_callback_.Reset(); @@ -128,6 +143,10 @@ void ClientSocketHandle::OnIOComplete(int result) { callback.Run(result); } +scoped_ptr<StreamSocket> ClientSocketHandle::PassSocket() { + return socket_.Pass(); +} + void ClientSocketHandle::HandleInitCompletion(int result) { CHECK_NE(ERR_IO_PENDING, result); ClientSocketPoolHistograms* histograms = pool_->histograms(); diff --git a/chromium/net/socket/client_socket_handle.h b/chromium/net/socket/client_socket_handle.h index 7d5588a145b..30b7c03e9dc 100644 --- a/chromium/net/socket/client_socket_handle.h +++ b/chromium/net/socket/client_socket_handle.h @@ -70,9 +70,9 @@ class NET_EXPORT ClientSocketHandle { // // Profiling information for the request is saved to |net_log| if non-NULL. // - template <typename SocketParams, typename PoolType> + template <typename PoolType> int Init(const std::string& group_name, - const scoped_refptr<SocketParams>& socket_params, + const scoped_refptr<typename PoolType::SocketParams>& socket_params, RequestPriority priority, const CompletionCallback& callback, PoolType* pool, @@ -94,9 +94,15 @@ class NET_EXPORT ClientSocketHandle { bool IsPoolStalled() const; - void AddLayeredPool(LayeredPool* layered_pool); + // Adds a higher layered pool on top of the socket pool that |socket_| belongs + // to. At most one higher layered pool can be added to a + // ClientSocketHandle at a time. On destruction or reset, automatically + // removes the higher pool if RemoveHigherLayeredPool has not been called. + void AddHigherLayeredPool(HigherLayeredPool* higher_pool); - void RemoveLayeredPool(LayeredPool* layered_pool); + // Removes a higher layered pool from the socket pool that |socket_| belongs + // to. |higher_pool| must have been added by the above function. + void RemoveHigherLayeredPool(HigherLayeredPool* higher_pool); // Returns true when Init() has completed successfully. bool is_initialized() const { return is_initialized_; } @@ -116,8 +122,11 @@ class NET_EXPORT ClientSocketHandle { LoadTimingInfo* load_timing_info) const; // Used by ClientSocketPool to initialize the ClientSocketHandle. + // + // SetSocket() may also be used if this handle is used as simply for + // socket storage (e.g., http://crbug.com/37810). + void SetSocket(scoped_ptr<StreamSocket> s); void set_is_reused(bool is_reused) { is_reused_ = is_reused; } - void set_socket(StreamSocket* s) { socket_.reset(s); } void set_idle_time(base::TimeDelta idle_time) { idle_time_ = idle_time; } void set_pool_id(int id) { pool_id_ = id; } void set_is_ssl_error(bool is_ssl_error) { is_ssl_error_ = is_ssl_error; } @@ -143,11 +152,15 @@ class NET_EXPORT ClientSocketHandle { return pending_http_proxy_connection_.release(); } + StreamSocket* socket() { return socket_.get(); } + + // SetSocket() must be called with a new socket before this handle + // is destroyed if is_initialized() is true. + scoped_ptr<StreamSocket> PassSocket(); + // These may only be used if is_initialized() is true. const std::string& group_name() const { return group_name_; } int id() const { return pool_id_; } - StreamSocket* socket() { return socket_.get(); } - StreamSocket* release_socket() { return socket_.release(); } bool is_reused() const { return is_reused_; } base::TimeDelta idle_time() const { return idle_time_; } SocketReuseType reuse_type() const { @@ -184,7 +197,7 @@ class NET_EXPORT ClientSocketHandle { bool is_initialized_; ClientSocketPool* pool_; - LayeredPool* layered_pool_; + HigherLayeredPool* higher_pool_; scoped_ptr<StreamSocket> socket_; std::string group_name_; bool is_reused_; @@ -207,20 +220,17 @@ class NET_EXPORT ClientSocketHandle { }; // Template function implementation: -template <typename SocketParams, typename PoolType> -int ClientSocketHandle::Init(const std::string& group_name, - const scoped_refptr<SocketParams>& socket_params, - RequestPriority priority, - const CompletionCallback& callback, - PoolType* pool, - const BoundNetLog& net_log) { +template <typename PoolType> +int ClientSocketHandle::Init( + const std::string& group_name, + const scoped_refptr<typename PoolType::SocketParams>& socket_params, + RequestPriority priority, + const CompletionCallback& callback, + PoolType* pool, + const BoundNetLog& net_log) { requesting_source_ = net_log.source(); CHECK(!group_name.empty()); - // Note that this will result in a compile error if the SocketParams has not - // been registered for the PoolType via REGISTER_SOCKET_PARAMS_FOR_POOL - // (defined in client_socket_pool.h). - CheckIsValidSocketParamsForPool<PoolType, SocketParams>(); ResetInternal(true); ResetErrorState(); pool_ = pool; diff --git a/chromium/net/socket/client_socket_pool.h b/chromium/net/socket/client_socket_pool.h index 7cb9a7ebc2e..715cddb94d4 100644 --- a/chromium/net/socket/client_socket_pool.h +++ b/chromium/net/socket/client_socket_pool.h @@ -10,6 +10,7 @@ #include "base/basictypes.h" #include "base/memory/ref_counted.h" +#include "base/memory/scoped_ptr.h" #include "base/template_util.h" #include "base/time/time.h" #include "net/base/completion_callback.h" @@ -30,20 +31,43 @@ class StreamSocket; // ClientSocketPools are layered. This defines an interface for lower level // socket pools to communicate with higher layer pools. -class NET_EXPORT LayeredPool { +class NET_EXPORT HigherLayeredPool { public: - virtual ~LayeredPool() {}; + virtual ~HigherLayeredPool() {} - // Instructs the LayeredPool to close an idle connection. Return true if one - // was closed. + // Instructs the HigherLayeredPool to close an idle connection. Return true if + // one was closed. Closing an idle connection will call into the lower layer + // pool it came from, so must be careful of re-entrancy when using this. virtual bool CloseOneIdleConnection() = 0; }; +// ClientSocketPools are layered. This defines an interface for higher level +// socket pools to communicate with lower layer pools. +class NET_EXPORT LowerLayeredPool { + public: + virtual ~LowerLayeredPool() {} + + // Returns true if a there is currently a request blocked on the per-pool + // (not per-host) max socket limit, either in this pool, or one that it is + // layered on top of. + virtual bool IsStalled() const = 0; + + // Called to add or remove a higher layer pool on top of |this|. A higher + // layer pool may be added at most once to |this|, and must be removed prior + // to destruction of |this|. + virtual void AddHigherLayeredPool(HigherLayeredPool* higher_pool) = 0; + virtual void RemoveHigherLayeredPool(HigherLayeredPool* higher_pool) = 0; +}; + // A ClientSocketPool is used to restrict the number of sockets open at a time. // It also maintains a list of idle persistent sockets. // -class NET_EXPORT ClientSocketPool { +class NET_EXPORT ClientSocketPool : public LowerLayeredPool { public: + // Subclasses must also have an inner class SocketParams which is + // the type for the |params| argument in RequestSocket() and + // RequestSockets() below. + // Requests a connected socket for a group_name. // // There are five possible results from calling this function: @@ -111,7 +135,7 @@ class NET_EXPORT ClientSocketPool { // change when it flushes, so it can use this |id| to discard sockets with // mismatched ids. virtual void ReleaseSocket(const std::string& group_name, - StreamSocket* socket, + scoped_ptr<StreamSocket> socket, int id) = 0; // This flushes all state from the ClientSocketPool. This means that all @@ -121,10 +145,6 @@ class NET_EXPORT ClientSocketPool { // Does not flush any pools wrapped by |this|. virtual void FlushWithError(int error) = 0; - // Returns true if a there is currently a request blocked on the - // per-pool (not per-host) max socket limit. - virtual bool IsStalled() const = 0; - // Called to close any idle connections held by the connection manager. virtual void CloseIdleSockets() = 0; @@ -138,12 +158,6 @@ class NET_EXPORT ClientSocketPool { virtual LoadState GetLoadState(const std::string& group_name, const ClientSocketHandle* handle) const = 0; - // Adds a LayeredPool on top of |this|. - virtual void AddLayeredPool(LayeredPool* layered_pool) = 0; - - // Removes a LayeredPool from |this|. - virtual void RemoveLayeredPool(LayeredPool* layered_pool) = 0; - // Retrieves information on the current state of the pool as a // DictionaryValue. Caller takes possession of the returned value. // If |include_nested_pools| is true, the states of any nested @@ -177,41 +191,13 @@ class NET_EXPORT ClientSocketPool { DISALLOW_COPY_AND_ASSIGN(ClientSocketPool); }; -// ClientSocketPool subclasses should indicate valid SocketParams via the -// REGISTER_SOCKET_PARAMS_FOR_POOL macro below. By default, any given -// <PoolType,SocketParams> pair will have its SocketParamsTrait inherit from -// base::false_type, but REGISTER_SOCKET_PARAMS_FOR_POOL will specialize that -// pairing to inherit from base::true_type. This provides compile time -// verification that the correct SocketParams type is used with the appropriate -// PoolType. -template <typename PoolType, typename SocketParams> -struct SocketParamTraits : public base::false_type { -}; - -template <typename PoolType, typename SocketParams> -void CheckIsValidSocketParamsForPool() { - COMPILE_ASSERT(!base::is_pointer<scoped_refptr<SocketParams> >::value, - socket_params_cannot_be_pointer); - COMPILE_ASSERT((SocketParamTraits<PoolType, - scoped_refptr<SocketParams> >::value), - invalid_socket_params_for_pool); -} - -// Provides an empty definition for CheckIsValidSocketParamsForPool() which -// should be optimized out by the compiler. -#define REGISTER_SOCKET_PARAMS_FOR_POOL(pool_type, socket_params) \ -template<> \ -struct SocketParamTraits<pool_type, scoped_refptr<socket_params> > \ - : public base::true_type { \ -} - -template <typename PoolType, typename SocketParams> -void RequestSocketsForPool(PoolType* pool, - const std::string& group_name, - const scoped_refptr<SocketParams>& params, - int num_sockets, - const BoundNetLog& net_log) { - CheckIsValidSocketParamsForPool<PoolType, SocketParams>(); +template <typename PoolType> +void RequestSocketsForPool( + PoolType* pool, + const std::string& group_name, + const scoped_refptr<typename PoolType::SocketParams>& params, + int num_sockets, + const BoundNetLog& net_log) { pool->RequestSockets(group_name, ¶ms, num_sockets, net_log); } diff --git a/chromium/net/socket/client_socket_pool_base.cc b/chromium/net/socket/client_socket_pool_base.cc index 3332e04a171..cec7956a0ee 100644 --- a/chromium/net/socket/client_socket_pool_base.cc +++ b/chromium/net/socket/client_socket_pool_base.cc @@ -65,10 +65,12 @@ int CompareEffectiveRequestPriority( ConnectJob::ConnectJob(const std::string& group_name, base::TimeDelta timeout_duration, + RequestPriority priority, Delegate* delegate, const BoundNetLog& net_log) : group_name_(group_name), timeout_duration_(timeout_duration), + priority_(priority), delegate_(delegate), net_log_(net_log), idle_(true) { @@ -82,6 +84,10 @@ ConnectJob::~ConnectJob() { net_log().EndEvent(NetLog::TYPE_SOCKET_POOL_CONNECT_JOB); } +scoped_ptr<StreamSocket> ConnectJob::PassSocket() { + return socket_.Pass(); +} + int ConnectJob::Connect() { if (timeout_duration_ != base::TimeDelta()) timer_.Start(FROM_HERE, timeout_duration_, this, &ConnectJob::OnTimeout); @@ -100,16 +106,16 @@ int ConnectJob::Connect() { return rv; } -void ConnectJob::set_socket(StreamSocket* socket) { +void ConnectJob::SetSocket(scoped_ptr<StreamSocket> socket) { if (socket) { net_log().AddEvent(NetLog::TYPE_CONNECT_JOB_SET_SOCKET, socket->NetLog().source().ToEventParametersCallback()); } - socket_.reset(socket); + socket_ = socket.Pass(); } void ConnectJob::NotifyDelegateOfCompletion(int rv) { - // The delegate will delete |this|. + // The delegate will own |this|. Delegate* delegate = delegate_; delegate_ = NULL; @@ -135,7 +141,7 @@ void ConnectJob::LogConnectCompletion(int net_error) { void ConnectJob::OnTimeout() { // Make sure the socket is NULL before calling into |delegate|. - set_socket(NULL); + SetSocket(scoped_ptr<StreamSocket>()); net_log_.AddEvent(NetLog::TYPE_SOCKET_POOL_CONNECT_JOB_TIMED_OUT); @@ -161,6 +167,7 @@ ClientSocketPoolBaseHelper::Request::Request( ClientSocketPoolBaseHelper::Request::~Request() {} ClientSocketPoolBaseHelper::ClientSocketPoolBaseHelper( + HigherLayeredPool* pool, int max_sockets, int max_sockets_per_group, base::TimeDelta unused_idle_socket_timeout, @@ -177,6 +184,7 @@ ClientSocketPoolBaseHelper::ClientSocketPoolBaseHelper( connect_job_factory_(connect_job_factory), connect_backup_jobs_enabled_(false), pool_generation_number_(0), + pool_(pool), weak_factory_(this) { DCHECK_LE(0, max_sockets_per_group); DCHECK_LE(max_sockets_per_group, max_sockets); @@ -192,9 +200,16 @@ ClientSocketPoolBaseHelper::~ClientSocketPoolBaseHelper() { DCHECK(group_map_.empty()); DCHECK(pending_callback_map_.empty()); DCHECK_EQ(0, connecting_socket_count_); - CHECK(higher_layer_pools_.empty()); + CHECK(higher_pools_.empty()); NetworkChangeNotifier::RemoveIPAddressObserver(this); + + // Remove from lower layer pools. + for (std::set<LowerLayeredPool*>::iterator it = lower_pools_.begin(); + it != lower_pools_.end(); + ++it) { + (*it)->RemoveHigherLayeredPool(pool_); + } } ClientSocketPoolBaseHelper::CallbackResultPair::CallbackResultPair() @@ -209,46 +224,59 @@ ClientSocketPoolBaseHelper::CallbackResultPair::CallbackResultPair( ClientSocketPoolBaseHelper::CallbackResultPair::~CallbackResultPair() {} -// static -void ClientSocketPoolBaseHelper::InsertRequestIntoQueue( - const Request* r, RequestQueue* pending_requests) { - RequestQueue::iterator it = pending_requests->begin(); - // TODO(mmenke): Should the network stack require requests with - // |ignore_limits| have the highest priority? - while (it != pending_requests->end() && - CompareEffectiveRequestPriority(*r, *(*it)) <= 0) { - ++it; +bool ClientSocketPoolBaseHelper::IsStalled() const { + // If a lower layer pool is stalled, consider |this| stalled as well. + for (std::set<LowerLayeredPool*>::const_iterator it = lower_pools_.begin(); + it != lower_pools_.end(); + ++it) { + if ((*it)->IsStalled()) + return true; + } + + // If fewer than |max_sockets_| are in use, then clearly |this| is not + // stalled. + if ((handed_out_socket_count_ + connecting_socket_count_) < max_sockets_) + return false; + // So in order to be stalled, |this| must be using at least |max_sockets_| AND + // |this| must have a request that is actually stalled on the global socket + // limit. To find such a request, look for a group that has more requests + // than jobs AND where the number of sockets is less than + // |max_sockets_per_group_|. (If the number of sockets is equal to + // |max_sockets_per_group_|, then the request is stalled on the group limit, + // which does not count.) + for (GroupMap::const_iterator it = group_map_.begin(); + it != group_map_.end(); ++it) { + if (it->second->IsStalledOnPoolMaxSockets(max_sockets_per_group_)) + return true; } - pending_requests->insert(it, r); + return false; } -// static -const ClientSocketPoolBaseHelper::Request* -ClientSocketPoolBaseHelper::RemoveRequestFromQueue( - const RequestQueue::iterator& it, Group* group) { - const Request* req = *it; - group->mutable_pending_requests()->erase(it); - // If there are no more requests, we kill the backup timer. - if (group->pending_requests().empty()) - group->CleanupBackupJob(); - return req; +void ClientSocketPoolBaseHelper::AddLowerLayeredPool( + LowerLayeredPool* lower_pool) { + DCHECK(pool_); + CHECK(!ContainsKey(lower_pools_, lower_pool)); + lower_pools_.insert(lower_pool); + lower_pool->AddHigherLayeredPool(pool_); } -void ClientSocketPoolBaseHelper::AddLayeredPool(LayeredPool* pool) { - CHECK(pool); - CHECK(!ContainsKey(higher_layer_pools_, pool)); - higher_layer_pools_.insert(pool); +void ClientSocketPoolBaseHelper::AddHigherLayeredPool( + HigherLayeredPool* higher_pool) { + CHECK(higher_pool); + CHECK(!ContainsKey(higher_pools_, higher_pool)); + higher_pools_.insert(higher_pool); } -void ClientSocketPoolBaseHelper::RemoveLayeredPool(LayeredPool* pool) { - CHECK(pool); - CHECK(ContainsKey(higher_layer_pools_, pool)); - higher_layer_pools_.erase(pool); +void ClientSocketPoolBaseHelper::RemoveHigherLayeredPool( + HigherLayeredPool* higher_pool) { + CHECK(higher_pool); + CHECK(ContainsKey(higher_pools_, higher_pool)); + higher_pools_.erase(higher_pool); } int ClientSocketPoolBaseHelper::RequestSocket( const std::string& group_name, - const Request* request) { + scoped_ptr<const Request> request) { CHECK(!request->callback().is_null()); CHECK(request->handle()); @@ -259,13 +287,13 @@ int ClientSocketPoolBaseHelper::RequestSocket( request->net_log().BeginEvent(NetLog::TYPE_SOCKET_POOL); Group* group = GetOrCreateGroup(group_name); - int rv = RequestSocketInternal(group_name, request); + int rv = RequestSocketInternal(group_name, *request); if (rv != ERR_IO_PENDING) { request->net_log().EndEventWithNetErrorCode(NetLog::TYPE_SOCKET_POOL, rv); CHECK(!request->handle()->is_initialized()); - delete request; + request.reset(); } else { - InsertRequestIntoQueue(request, group->mutable_pending_requests()); + group->InsertPendingRequest(request.Pass()); // Have to do this asynchronously, as closing sockets in higher level pools // call back in to |this|, which will cause all sorts of fun and exciting // re-entrancy issues if the socket pool is doing something else at the @@ -309,7 +337,7 @@ void ClientSocketPoolBaseHelper::RequestSockets( for (int num_iterations_left = num_sockets; group->NumActiveSocketSlots() < num_sockets && num_iterations_left > 0 ; num_iterations_left--) { - rv = RequestSocketInternal(group_name, &request); + rv = RequestSocketInternal(group_name, request); if (rv < 0 && rv != ERR_IO_PENDING) { // We're encountering a synchronous error. Give up. if (!ContainsKey(group_map_, group_name)) @@ -336,12 +364,12 @@ void ClientSocketPoolBaseHelper::RequestSockets( int ClientSocketPoolBaseHelper::RequestSocketInternal( const std::string& group_name, - const Request* request) { - ClientSocketHandle* const handle = request->handle(); + const Request& request) { + ClientSocketHandle* const handle = request.handle(); const bool preconnecting = !handle; Group* group = GetOrCreateGroup(group_name); - if (!(request->flags() & NO_IDLE_SOCKETS)) { + if (!(request.flags() & NO_IDLE_SOCKETS)) { // Try to reuse a socket. if (AssignIdleSocketToRequest(request, group)) return OK; @@ -355,17 +383,17 @@ int ClientSocketPoolBaseHelper::RequestSocketInternal( // Can we make another active socket now? if (!group->HasAvailableSocketSlot(max_sockets_per_group_) && - !request->ignore_limits()) { + !request.ignore_limits()) { // TODO(willchan): Consider whether or not we need to close a socket in a // higher layered group. I don't think this makes sense since we would just // reuse that socket then if we needed one and wouldn't make it down to this // layer. - request->net_log().AddEvent( + request.net_log().AddEvent( NetLog::TYPE_SOCKET_POOL_STALLED_MAX_SOCKETS_PER_GROUP); return ERR_IO_PENDING; } - if (ReachedMaxSocketsLimit() && !request->ignore_limits()) { + if (ReachedMaxSocketsLimit() && !request.ignore_limits()) { // NOTE(mmenke): Wonder if we really need different code for each case // here. Only reason for them now seems to be preconnects. if (idle_socket_count() > 0) { @@ -378,7 +406,7 @@ int ClientSocketPoolBaseHelper::RequestSocketInternal( } else { // We could check if we really have a stalled group here, but it requires // a scan of all groups, so just flip a flag here, and do the check later. - request->net_log().AddEvent(NetLog::TYPE_SOCKET_POOL_STALLED_MAX_SOCKETS); + request.net_log().AddEvent(NetLog::TYPE_SOCKET_POOL_STALLED_MAX_SOCKETS); return ERR_IO_PENDING; } } @@ -386,17 +414,17 @@ int ClientSocketPoolBaseHelper::RequestSocketInternal( // We couldn't find a socket to reuse, and there's space to allocate one, // so allocate and connect a new one. scoped_ptr<ConnectJob> connect_job( - connect_job_factory_->NewConnectJob(group_name, *request, this)); + connect_job_factory_->NewConnectJob(group_name, request, this)); int rv = connect_job->Connect(); if (rv == OK) { LogBoundConnectJobToRequest(connect_job->net_log().source(), request); if (!preconnecting) { - HandOutSocket(connect_job->ReleaseSocket(), false /* not reused */, + HandOutSocket(connect_job->PassSocket(), false /* not reused */, connect_job->connect_timing(), handle, base::TimeDelta(), - group, request->net_log()); + group, request.net_log()); } else { - AddIdleSocket(connect_job->ReleaseSocket(), group); + AddIdleSocket(connect_job->PassSocket(), group); } } else if (rv == ERR_IO_PENDING) { // If we don't have any sockets in this group, set a timer for potentially @@ -409,19 +437,19 @@ int ClientSocketPoolBaseHelper::RequestSocketInternal( connecting_socket_count_++; - group->AddJob(connect_job.release(), preconnecting); + group->AddJob(connect_job.Pass(), preconnecting); } else { LogBoundConnectJobToRequest(connect_job->net_log().source(), request); - StreamSocket* error_socket = NULL; + scoped_ptr<StreamSocket> error_socket; if (!preconnecting) { DCHECK(handle); connect_job->GetAdditionalErrorState(handle); - error_socket = connect_job->ReleaseSocket(); + error_socket = connect_job->PassSocket(); } if (error_socket) { - HandOutSocket(error_socket, false /* not reused */, + HandOutSocket(error_socket.Pass(), false /* not reused */, connect_job->connect_timing(), handle, base::TimeDelta(), - group, request->net_log()); + group, request.net_log()); } else if (group->IsEmpty()) { RemoveGroup(group_name); } @@ -431,7 +459,7 @@ int ClientSocketPoolBaseHelper::RequestSocketInternal( } bool ClientSocketPoolBaseHelper::AssignIdleSocketToRequest( - const Request* request, Group* group) { + const Request& request, Group* group) { std::list<IdleSocket>* idle_sockets = group->mutable_idle_sockets(); std::list<IdleSocket>::iterator idle_socket_it = idle_sockets->end(); @@ -469,13 +497,13 @@ bool ClientSocketPoolBaseHelper::AssignIdleSocketToRequest( IdleSocket idle_socket = *idle_socket_it; idle_sockets->erase(idle_socket_it); HandOutSocket( - idle_socket.socket, + scoped_ptr<StreamSocket>(idle_socket.socket), idle_socket.socket->WasEverUsed(), LoadTimingInfo::ConnectTiming(), - request->handle(), + request.handle(), idle_time, group, - request->net_log()); + request.net_log()); return true; } @@ -484,9 +512,9 @@ bool ClientSocketPoolBaseHelper::AssignIdleSocketToRequest( // static void ClientSocketPoolBaseHelper::LogBoundConnectJobToRequest( - const NetLog::Source& connect_job_source, const Request* request) { - request->net_log().AddEvent(NetLog::TYPE_SOCKET_POOL_BOUND_TO_CONNECT_JOB, - connect_job_source.ToEventParametersCallback()); + const NetLog::Source& connect_job_source, const Request& request) { + request.net_log().AddEvent(NetLog::TYPE_SOCKET_POOL_BOUND_TO_CONNECT_JOB, + connect_job_source.ToEventParametersCallback()); } void ClientSocketPoolBaseHelper::CancelRequest( @@ -495,11 +523,11 @@ void ClientSocketPoolBaseHelper::CancelRequest( if (callback_it != pending_callback_map_.end()) { int result = callback_it->second.result; pending_callback_map_.erase(callback_it); - StreamSocket* socket = handle->release_socket(); + scoped_ptr<StreamSocket> socket = handle->PassSocket(); if (socket) { if (result != OK) socket->Disconnect(); - ReleaseSocket(handle->group_name(), socket, handle->id()); + ReleaseSocket(handle->group_name(), socket.Pass(), handle->id()); } return; } @@ -509,21 +537,18 @@ void ClientSocketPoolBaseHelper::CancelRequest( Group* group = GetOrCreateGroup(group_name); // Search pending_requests for matching handle. - RequestQueue::iterator it = group->mutable_pending_requests()->begin(); - for (; it != group->pending_requests().end(); ++it) { - if ((*it)->handle() == handle) { - scoped_ptr<const Request> req(RemoveRequestFromQueue(it, group)); - req->net_log().AddEvent(NetLog::TYPE_CANCELLED); - req->net_log().EndEvent(NetLog::TYPE_SOCKET_POOL); - - // We let the job run, unless we're at the socket limit and there is - // not another request waiting on the job. - if (group->jobs().size() > group->pending_requests().size() && - ReachedMaxSocketsLimit()) { - RemoveConnectJob(*group->jobs().begin(), group); - CheckForStalledSocketGroups(); - } - break; + scoped_ptr<const Request> request = + group->FindAndRemovePendingRequest(handle); + if (request) { + request->net_log().AddEvent(NetLog::TYPE_CANCELLED); + request->net_log().EndEvent(NetLog::TYPE_SOCKET_POOL); + + // We let the job run, unless we're at the socket limit and there is + // not another request waiting on the job. + if (group->jobs().size() > group->pending_request_count() && + ReachedMaxSocketsLimit()) { + RemoveConnectJob(*group->jobs().begin(), group); + CheckForStalledSocketGroups(); } } } @@ -560,16 +585,7 @@ LoadState ClientSocketPoolBaseHelper::GetLoadState( // Can't use operator[] since it is non-const. const Group& group = *group_map_.find(group_name)->second; - // Search the first group.jobs().size() |pending_requests| for |handle|. - // If it's farther back in the deque than that, it doesn't have a - // corresponding ConnectJob. - size_t connect_jobs = group.jobs().size(); - RequestQueue::const_iterator it = group.pending_requests().begin(); - for (size_t i = 0; it != group.pending_requests().end() && i < connect_jobs; - ++it, ++i) { - if ((*it)->handle() != handle) - continue; - + if (group.HasConnectJobForHandle(handle)) { // Just return the state of the farthest along ConnectJob for the first // group.jobs().size() pending requests. LoadState max_state = LOAD_STATE_IDLE; @@ -607,8 +623,8 @@ base::DictionaryValue* ClientSocketPoolBaseHelper::GetInfoAsValue( base::DictionaryValue* group_dict = new base::DictionaryValue(); group_dict->SetInteger("pending_request_count", - group->pending_requests().size()); - if (!group->pending_requests().empty()) { + group->pending_request_count()); + if (group->has_pending_requests()) { group_dict->SetInteger("top_pending_priority", group->TopPendingPriority()); } @@ -756,7 +772,7 @@ void ClientSocketPoolBaseHelper::StartIdleSocketTimer() { } void ClientSocketPoolBaseHelper::ReleaseSocket(const std::string& group_name, - StreamSocket* socket, + scoped_ptr<StreamSocket> socket, int id) { GroupMap::iterator i = group_map_.find(group_name); CHECK(i != group_map_.end()); @@ -773,10 +789,10 @@ void ClientSocketPoolBaseHelper::ReleaseSocket(const std::string& group_name, id == pool_generation_number_; if (can_reuse) { // Add it to the idle list. - AddIdleSocket(socket, group); + AddIdleSocket(socket.Pass(), group); OnAvailableSocketSlot(group_name, group); } else { - delete socket; + socket.reset(); } CheckForStalledSocketGroups(); @@ -786,8 +802,18 @@ void ClientSocketPoolBaseHelper::CheckForStalledSocketGroups() { // If we have idle sockets, see if we can give one to the top-stalled group. std::string top_group_name; Group* top_group = NULL; - if (!FindTopStalledGroup(&top_group, &top_group_name)) + if (!FindTopStalledGroup(&top_group, &top_group_name)) { + // There may still be a stalled group in a lower level pool. + for (std::set<LowerLayeredPool*>::iterator it = lower_pools_.begin(); + it != lower_pools_.end(); + ++it) { + if ((*it)->IsStalled()) { + CloseOneIdleSocket(); + break; + } + } return; + } if (ReachedMaxSocketsLimit()) { if (idle_socket_count() > 0) { @@ -820,8 +846,7 @@ bool ClientSocketPoolBaseHelper::FindTopStalledGroup( for (GroupMap::const_iterator i = group_map_.begin(); i != group_map_.end(); ++i) { Group* curr_group = i->second; - const RequestQueue& queue = curr_group->pending_requests(); - if (queue.empty()) + if (!curr_group->has_pending_requests()) continue; if (curr_group->IsStalledOnPoolMaxSockets(max_sockets_per_group_)) { if (!group) @@ -854,27 +879,29 @@ void ClientSocketPoolBaseHelper::OnConnectJobComplete( CHECK(group_it != group_map_.end()); Group* group = group_it->second; - scoped_ptr<StreamSocket> socket(job->ReleaseSocket()); + scoped_ptr<StreamSocket> socket = job->PassSocket(); // Copies of these are needed because |job| may be deleted before they are // accessed. BoundNetLog job_log = job->net_log(); LoadTimingInfo::ConnectTiming connect_timing = job->connect_timing(); + // RemoveConnectJob(job, _) must be called by all branches below; + // otherwise, |job| will be leaked. + if (result == OK) { DCHECK(socket.get()); RemoveConnectJob(job, group); - if (!group->pending_requests().empty()) { - scoped_ptr<const Request> r(RemoveRequestFromQueue( - group->mutable_pending_requests()->begin(), group)); - LogBoundConnectJobToRequest(job_log.source(), r.get()); + scoped_ptr<const Request> request = group->PopNextPendingRequest(); + if (request) { + LogBoundConnectJobToRequest(job_log.source(), *request); HandOutSocket( - socket.release(), false /* unused socket */, connect_timing, - r->handle(), base::TimeDelta(), group, r->net_log()); - r->net_log().EndEvent(NetLog::TYPE_SOCKET_POOL); - InvokeUserCallbackLater(r->handle(), r->callback(), result); + socket.Pass(), false /* unused socket */, connect_timing, + request->handle(), base::TimeDelta(), group, request->net_log()); + request->net_log().EndEvent(NetLog::TYPE_SOCKET_POOL); + InvokeUserCallbackLater(request->handle(), request->callback(), result); } else { - AddIdleSocket(socket.release(), group); + AddIdleSocket(socket.Pass(), group); OnAvailableSocketSlot(group_name, group); CheckForStalledSocketGroups(); } @@ -882,20 +909,20 @@ void ClientSocketPoolBaseHelper::OnConnectJobComplete( // If we got a socket, it must contain error information so pass that // up so that the caller can retrieve it. bool handed_out_socket = false; - if (!group->pending_requests().empty()) { - scoped_ptr<const Request> r(RemoveRequestFromQueue( - group->mutable_pending_requests()->begin(), group)); - LogBoundConnectJobToRequest(job_log.source(), r.get()); - job->GetAdditionalErrorState(r->handle()); + scoped_ptr<const Request> request = group->PopNextPendingRequest(); + if (request) { + LogBoundConnectJobToRequest(job_log.source(), *request); + job->GetAdditionalErrorState(request->handle()); RemoveConnectJob(job, group); if (socket.get()) { handed_out_socket = true; - HandOutSocket(socket.release(), false /* unused socket */, - connect_timing, r->handle(), base::TimeDelta(), group, - r->net_log()); + HandOutSocket(socket.Pass(), false /* unused socket */, + connect_timing, request->handle(), base::TimeDelta(), + group, request->net_log()); } - r->net_log().EndEventWithNetErrorCode(NetLog::TYPE_SOCKET_POOL, result); - InvokeUserCallbackLater(r->handle(), r->callback(), result); + request->net_log().EndEventWithNetErrorCode( + NetLog::TYPE_SOCKET_POOL, result); + InvokeUserCallbackLater(request->handle(), request->callback(), result); } else { RemoveConnectJob(job, group); } @@ -917,59 +944,38 @@ void ClientSocketPoolBaseHelper::FlushWithError(int error) { CancelAllRequestsWithError(error); } -bool ClientSocketPoolBaseHelper::IsStalled() const { - // If we are not using |max_sockets_|, then clearly we are not stalled - if ((handed_out_socket_count_ + connecting_socket_count_) < max_sockets_) - return false; - // So in order to be stalled we need to be using |max_sockets_| AND - // we need to have a request that is actually stalled on the global - // socket limit. To find such a request, we look for a group that - // a has more requests that jobs AND where the number of jobs is less - // than |max_sockets_per_group_|. (If the number of jobs is equal to - // |max_sockets_per_group_|, then the request is stalled on the group, - // which does not count.) - for (GroupMap::const_iterator it = group_map_.begin(); - it != group_map_.end(); ++it) { - if (it->second->IsStalledOnPoolMaxSockets(max_sockets_per_group_)) - return true; - } - return false; -} - void ClientSocketPoolBaseHelper::RemoveConnectJob(ConnectJob* job, Group* group) { CHECK_GT(connecting_socket_count_, 0); connecting_socket_count_--; DCHECK(group); - DCHECK(ContainsKey(group->jobs(), job)); group->RemoveJob(job); // If we've got no more jobs for this group, then we no longer need a // backup job either. if (group->jobs().empty()) group->CleanupBackupJob(); - - DCHECK(job); - delete job; } void ClientSocketPoolBaseHelper::OnAvailableSocketSlot( const std::string& group_name, Group* group) { DCHECK(ContainsKey(group_map_, group_name)); - if (group->IsEmpty()) + if (group->IsEmpty()) { RemoveGroup(group_name); - else if (!group->pending_requests().empty()) + } else if (group->has_pending_requests()) { ProcessPendingRequest(group_name, group); + } } void ClientSocketPoolBaseHelper::ProcessPendingRequest( const std::string& group_name, Group* group) { - int rv = RequestSocketInternal(group_name, - *group->pending_requests().begin()); + const Request* next_request = group->GetNextPendingRequest(); + DCHECK(next_request); + int rv = RequestSocketInternal(group_name, *next_request); if (rv != ERR_IO_PENDING) { - scoped_ptr<const Request> request(RemoveRequestFromQueue( - group->mutable_pending_requests()->begin(), group)); + scoped_ptr<const Request> request = group->PopNextPendingRequest(); + DCHECK(request); if (group->IsEmpty()) RemoveGroup(group_name); @@ -979,7 +985,7 @@ void ClientSocketPoolBaseHelper::ProcessPendingRequest( } void ClientSocketPoolBaseHelper::HandOutSocket( - StreamSocket* socket, + scoped_ptr<StreamSocket> socket, bool reused, const LoadTimingInfo::ConnectTiming& connect_timing, ClientSocketHandle* handle, @@ -987,7 +993,7 @@ void ClientSocketPoolBaseHelper::HandOutSocket( Group* group, const BoundNetLog& net_log) { DCHECK(socket); - handle->set_socket(socket); + handle->SetSocket(socket.Pass()); handle->set_is_reused(reused); handle->set_idle_time(idle_time); handle->set_pool_id(pool_generation_number_); @@ -1000,18 +1006,20 @@ void ClientSocketPoolBaseHelper::HandOutSocket( "idle_ms", static_cast<int>(idle_time.InMilliseconds()))); } - net_log.AddEvent(NetLog::TYPE_SOCKET_POOL_BOUND_TO_SOCKET, - socket->NetLog().source().ToEventParametersCallback()); + net_log.AddEvent( + NetLog::TYPE_SOCKET_POOL_BOUND_TO_SOCKET, + handle->socket()->NetLog().source().ToEventParametersCallback()); handed_out_socket_count_++; group->IncrementActiveSocketCount(); } void ClientSocketPoolBaseHelper::AddIdleSocket( - StreamSocket* socket, Group* group) { + scoped_ptr<StreamSocket> socket, + Group* group) { DCHECK(socket); IdleSocket idle_socket; - idle_socket.socket = socket; + idle_socket.socket = socket.release(); idle_socket.start_time = base::TimeTicks::Now(); group->mutable_idle_sockets()->push_back(idle_socket); @@ -1041,13 +1049,11 @@ void ClientSocketPoolBaseHelper::CancelAllRequestsWithError(int error) { for (GroupMap::iterator i = group_map_.begin(); i != group_map_.end();) { Group* group = i->second; - RequestQueue pending_requests; - pending_requests.swap(*group->mutable_pending_requests()); - for (RequestQueue::iterator it2 = pending_requests.begin(); - it2 != pending_requests.end(); ++it2) { - scoped_ptr<const Request> request(*it2); - InvokeUserCallbackLater( - request->handle(), request->callback(), error); + while (true) { + scoped_ptr<const Request> request = group->PopNextPendingRequest(); + if (!request) + break; + InvokeUserCallbackLater(request->handle(), request->callback(), error); } // Delete group if no longer needed. @@ -1103,12 +1109,12 @@ bool ClientSocketPoolBaseHelper::CloseOneIdleSocketExceptInGroup( return false; } -bool ClientSocketPoolBaseHelper::CloseOneIdleConnectionInLayeredPool() { +bool ClientSocketPoolBaseHelper::CloseOneIdleConnectionInHigherLayeredPool() { // This pool doesn't have any idle sockets. It's possible that a pool at a // higher layer is holding one of this sockets active, but it's actually idle. // Query the higher layers. - for (std::set<LayeredPool*>::const_iterator it = higher_layer_pools_.begin(); - it != higher_layer_pools_.end(); ++it) { + for (std::set<HigherLayeredPool*>::const_iterator it = higher_pools_.begin(); + it != higher_pools_.end(); ++it) { if ((*it)->CloseOneIdleConnection()) return true; } @@ -1144,7 +1150,7 @@ void ClientSocketPoolBaseHelper::TryToCloseSocketsInLayeredPools() { while (IsStalled()) { // Closing a socket will result in calling back into |this| to use the freed // socket slot, so nothing else is needed. - if (!CloseOneIdleConnectionInLayeredPool()) + if (!CloseOneIdleConnectionInHigherLayeredPool()) return; } } @@ -1182,19 +1188,25 @@ bool ClientSocketPoolBaseHelper::Group::TryToUseUnassignedConnectJob() { return true; } -void ClientSocketPoolBaseHelper::Group::AddJob(ConnectJob* job, +void ClientSocketPoolBaseHelper::Group::AddJob(scoped_ptr<ConnectJob> job, bool is_preconnect) { SanityCheck(); if (is_preconnect) ++unassigned_job_count_; - jobs_.insert(job); + jobs_.insert(job.release()); } void ClientSocketPoolBaseHelper::Group::RemoveJob(ConnectJob* job) { + scoped_ptr<ConnectJob> owned_job(job); SanityCheck(); - jobs_.erase(job); + std::set<ConnectJob*>::iterator it = jobs_.find(job); + if (it != jobs_.end()) { + jobs_.erase(it); + } else { + NOTREACHED(); + } size_t job_count = jobs_.size(); if (job_count < unassigned_job_count_) unassigned_job_count_ = job_count; @@ -1222,15 +1234,17 @@ void ClientSocketPoolBaseHelper::Group::OnBackupSocketTimerFired( if (pending_requests_.empty()) return; - ConnectJob* backup_job = pool->connect_job_factory_->NewConnectJob( - group_name, **pending_requests_.begin(), pool); + scoped_ptr<ConnectJob> backup_job = + pool->connect_job_factory_->NewConnectJob( + group_name, **pending_requests_.begin(), pool); backup_job->net_log().AddEvent(NetLog::TYPE_SOCKET_BACKUP_CREATED); SIMPLE_STATS_COUNTER("socket.backup_created"); int rv = backup_job->Connect(); pool->connecting_socket_count_++; - AddJob(backup_job, false); + ConnectJob* raw_backup_job = backup_job.get(); + AddJob(backup_job.Pass(), false); if (rv != ERR_IO_PENDING) - pool->OnConnectJobComplete(rv, backup_job); + pool->OnConnectJobComplete(rv, raw_backup_job); } void ClientSocketPoolBaseHelper::Group::SanityCheck() { @@ -1248,6 +1262,68 @@ void ClientSocketPoolBaseHelper::Group::RemoveAllJobs() { weak_factory_.InvalidateWeakPtrs(); } +const ClientSocketPoolBaseHelper::Request* +ClientSocketPoolBaseHelper::Group::GetNextPendingRequest() const { + return pending_requests_.empty() ? NULL : *pending_requests_.begin(); +} + +bool ClientSocketPoolBaseHelper::Group::HasConnectJobForHandle( + const ClientSocketHandle* handle) const { + // Search the first |jobs_.size()| pending requests for |handle|. + // If it's farther back in the deque than that, it doesn't have a + // corresponding ConnectJob. + size_t i = 0; + for (RequestQueue::const_iterator it = pending_requests_.begin(); + it != pending_requests_.end() && i < jobs_.size(); ++it, ++i) { + if ((*it)->handle() == handle) + return true; + } + return false; +} + +void ClientSocketPoolBaseHelper::Group::InsertPendingRequest( + scoped_ptr<const Request> r) { + RequestQueue::iterator it = pending_requests_.begin(); + // TODO(mmenke): Should the network stack require requests with + // |ignore_limits| have the highest priority? + while (it != pending_requests_.end() && + CompareEffectiveRequestPriority(*r, *(*it)) <= 0) { + ++it; + } + pending_requests_.insert(it, r.release()); +} + +scoped_ptr<const ClientSocketPoolBaseHelper::Request> +ClientSocketPoolBaseHelper::Group::PopNextPendingRequest() { + if (pending_requests_.empty()) + return scoped_ptr<const ClientSocketPoolBaseHelper::Request>(); + return RemovePendingRequest(pending_requests_.begin()); +} + +scoped_ptr<const ClientSocketPoolBaseHelper::Request> +ClientSocketPoolBaseHelper::Group::FindAndRemovePendingRequest( + ClientSocketHandle* handle) { + for (RequestQueue::iterator it = pending_requests_.begin(); + it != pending_requests_.end(); ++it) { + if ((*it)->handle() == handle) { + scoped_ptr<const Request> request = RemovePendingRequest(it); + return request.Pass(); + } + } + return scoped_ptr<const ClientSocketPoolBaseHelper::Request>(); +} + +scoped_ptr<const ClientSocketPoolBaseHelper::Request> +ClientSocketPoolBaseHelper::Group::RemovePendingRequest( + const RequestQueue::iterator& it) { + scoped_ptr<const Request> request(*it); + pending_requests_.erase(it); + // If there are no more requests, kill the backup timer. + if (pending_requests_.empty()) + CleanupBackupJob(); + return request.Pass(); +} + } // namespace internal } // namespace net diff --git a/chromium/net/socket/client_socket_pool_base.h b/chromium/net/socket/client_socket_pool_base.h index 4bf95d7b04a..31ec9bf7b13 100644 --- a/chromium/net/socket/client_socket_pool_base.h +++ b/chromium/net/socket/client_socket_pool_base.h @@ -61,8 +61,11 @@ class NET_EXPORT_PRIVATE ConnectJob { Delegate() {} virtual ~Delegate() {} - // Alerts the delegate that the connection completed. - virtual void OnConnectJobComplete(int result, ConnectJob* job) = 0; + // Alerts the delegate that the connection completed. |job| must + // be destroyed by the delegate. A scoped_ptr<> isn't used because + // the caller of this function doesn't own |job|. + virtual void OnConnectJobComplete(int result, + ConnectJob* job) = 0; private: DISALLOW_COPY_AND_ASSIGN(Delegate); @@ -71,6 +74,7 @@ class NET_EXPORT_PRIVATE ConnectJob { // A |timeout_duration| of 0 corresponds to no timeout. ConnectJob(const std::string& group_name, base::TimeDelta timeout_duration, + RequestPriority priority, Delegate* delegate, const BoundNetLog& net_log); virtual ~ConnectJob(); @@ -79,9 +83,10 @@ class NET_EXPORT_PRIVATE ConnectJob { const std::string& group_name() const { return group_name_; } const BoundNetLog& net_log() { return net_log_; } - // Releases |socket_| to the client. On connection error, this should return - // NULL. - StreamSocket* ReleaseSocket() { return socket_.release(); } + // Releases ownership of the underlying socket to the caller. + // Returns the released socket, or NULL if there was a connection + // error. + scoped_ptr<StreamSocket> PassSocket(); // Begins connecting the socket. Returns OK on success, ERR_IO_PENDING if it // cannot complete synchronously without blocking, or another net error code @@ -105,7 +110,8 @@ class NET_EXPORT_PRIVATE ConnectJob { const BoundNetLog& net_log() const { return net_log_; } protected: - void set_socket(StreamSocket* socket); + RequestPriority priority() const { return priority_; } + void SetSocket(scoped_ptr<StreamSocket> socket); StreamSocket* socket() { return socket_.get(); } void NotifyDelegateOfCompletion(int rv); void ResetTimer(base::TimeDelta remainingTime); @@ -124,6 +130,8 @@ class NET_EXPORT_PRIVATE ConnectJob { const std::string group_name_; const base::TimeDelta timeout_duration_; + // TODO(akalin): Support reprioritization. + const RequestPriority priority_; // Timer to abort jobs that take too long. base::OneShotTimer<ConnectJob> timer_; Delegate* delegate_; @@ -175,6 +183,7 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper private: ClientSocketHandle* const handle_; CompletionCallback callback_; + // TODO(akalin): Support reprioritization. const RequestPriority priority_; bool ignore_limits_; const Flags flags_; @@ -188,7 +197,7 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper ConnectJobFactory() {} virtual ~ConnectJobFactory() {} - virtual ConnectJob* NewConnectJob( + virtual scoped_ptr<ConnectJob> NewConnectJob( const std::string& group_name, const Request& request, ConnectJob::Delegate* delegate) const = 0; @@ -200,6 +209,7 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper }; ClientSocketPoolBaseHelper( + HigherLayeredPool* pool, int max_sockets, int max_sockets_per_group, base::TimeDelta unused_idle_socket_timeout, @@ -208,15 +218,21 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper virtual ~ClientSocketPoolBaseHelper(); - // Adds/Removes layered pools. It is expected in the destructor that no - // layered pools remain. - void AddLayeredPool(LayeredPool* pool); - void RemoveLayeredPool(LayeredPool* pool); + // Adds a lower layered pool to |this|, and adds |this| as a higher layered + // pool on top of |lower_pool|. + void AddLowerLayeredPool(LowerLayeredPool* lower_pool); + + // See LowerLayeredPool::IsStalled for documentation on this function. + bool IsStalled() const; + + // See LowerLayeredPool for documentation on these functions. It is expected + // in the destructor that no higher layer pools remain. + void AddHigherLayeredPool(HigherLayeredPool* higher_pool); + void RemoveHigherLayeredPool(HigherLayeredPool* higher_pool); // See ClientSocketPool::RequestSocket for documentation on this function. - // ClientSocketPoolBaseHelper takes ownership of |request|, which must be - // heap allocated. - int RequestSocket(const std::string& group_name, const Request* request); + int RequestSocket(const std::string& group_name, + scoped_ptr<const Request> request); // See ClientSocketPool::RequestSocket for documentation on this function. void RequestSockets(const std::string& group_name, @@ -229,15 +245,12 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper // See ClientSocketPool::ReleaseSocket for documentation on this function. void ReleaseSocket(const std::string& group_name, - StreamSocket* socket, + scoped_ptr<StreamSocket> socket, int id); // See ClientSocketPool::FlushWithError for documentation on this function. void FlushWithError(int error); - // See ClientSocketPool::IsStalled for documentation on this function. - bool IsStalled() const; - // See ClientSocketPool::CloseIdleSockets for documentation on this function. void CloseIdleSockets(); @@ -294,8 +307,8 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper // I'm not sure if we hit this situation often. bool CloseOneIdleSocket(); - // Checks layered pools to see if they can close an idle connection. - bool CloseOneIdleConnectionInLayeredPool(); + // Checks higher layered pools to see if they can close an idle connection. + bool CloseOneIdleConnectionInHigherLayeredPool(); // See ClientSocketPool::GetInfoAsValue for documentation on this function. base::DictionaryValue* GetInfoAsValue(const std::string& name, @@ -386,22 +399,55 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper // Otherwise, returns false. bool TryToUseUnassignedConnectJob(); - void AddJob(ConnectJob* job, bool is_preconnect); + void AddJob(scoped_ptr<ConnectJob> job, bool is_preconnect); + // Remove |job| from this group, which must already own |job|. void RemoveJob(ConnectJob* job); void RemoveAllJobs(); + bool has_pending_requests() const { + return !pending_requests_.empty(); + } + + size_t pending_request_count() const { + return pending_requests_.size(); + } + + // Gets (but does not remove) the next pending request. Returns + // NULL if there are no pending requests. + const Request* GetNextPendingRequest() const; + + // Returns true if there is a connect job for |handle|. + bool HasConnectJobForHandle(const ClientSocketHandle* handle) const; + + // Inserts the request into the queue based on priority + // order. Older requests are prioritized over requests of equal + // priority. + void InsertPendingRequest(scoped_ptr<const Request> r); + + // Gets and removes the next pending request. Returns NULL if + // there are no pending requests. + scoped_ptr<const Request> PopNextPendingRequest(); + + // Finds the pending request for |handle| and removes it. Returns + // the removed pending request, or NULL if there was none. + scoped_ptr<const Request> FindAndRemovePendingRequest( + ClientSocketHandle* handle); + void IncrementActiveSocketCount() { active_socket_count_++; } void DecrementActiveSocketCount() { active_socket_count_--; } int unassigned_job_count() const { return unassigned_job_count_; } const std::set<ConnectJob*>& jobs() const { return jobs_; } const std::list<IdleSocket>& idle_sockets() const { return idle_sockets_; } - const RequestQueue& pending_requests() const { return pending_requests_; } int active_socket_count() const { return active_socket_count_; } - RequestQueue* mutable_pending_requests() { return &pending_requests_; } std::list<IdleSocket>* mutable_idle_sockets() { return &idle_sockets_; } private: + // Returns the iterator's pending request after removing it from + // the queue. + scoped_ptr<const Request> RemovePendingRequest( + const RequestQueue::iterator& it); + // Called when the backup socket timer fires. void OnBackupSocketTimerFired( std::string group_name, @@ -443,15 +489,6 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper typedef std::map<const ClientSocketHandle*, CallbackResultPair> PendingCallbackMap; - // Inserts the request into the queue based on order they will receive - // sockets. Sockets which ignore the socket pool limits are first. Then - // requests are sorted by priority, with higher priorities closer to the - // front. Older requests are prioritized over requests of equal priority. - static void InsertRequestIntoQueue(const Request* r, - RequestQueue* pending_requests); - static const Request* RemoveRequestFromQueue(const RequestQueue::iterator& it, - Group* group); - Group* GetOrCreateGroup(const std::string& group_name); void RemoveGroup(const std::string& group_name); void RemoveGroup(GroupMap::iterator it); @@ -475,7 +512,7 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper CleanupIdleSockets(false); } - // Removes |job| from |connect_job_set_|. Also updates |group| if non-NULL. + // Removes |job| from |group|, which must already own |job|. void RemoveConnectJob(ConnectJob* job, Group* group); // Tries to see if we can handle any more requests for |group|. @@ -485,7 +522,7 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper void ProcessPendingRequest(const std::string& group_name, Group* group); // Assigns |socket| to |handle| and updates |group|'s counters appropriately. - void HandOutSocket(StreamSocket* socket, + void HandOutSocket(scoped_ptr<StreamSocket> socket, bool reused, const LoadTimingInfo::ConnectTiming& connect_timing, ClientSocketHandle* handle, @@ -494,7 +531,7 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper const BoundNetLog& net_log); // Adds |socket| to the list of idle sockets for |group|. - void AddIdleSocket(StreamSocket* socket, Group* group); + void AddIdleSocket(scoped_ptr<StreamSocket> socket, Group* group); // Iterates through |group_map_|, canceling all ConnectJobs and deleting // groups if they are no longer needed. @@ -511,14 +548,14 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper // it does not handle logging into NetLog of the queueing status of // |request|. int RequestSocketInternal(const std::string& group_name, - const Request* request); + const Request& request); // Assigns an idle socket for the group to the request. // Returns |true| if an idle socket is available, false otherwise. - bool AssignIdleSocketToRequest(const Request* request, Group* group); + bool AssignIdleSocketToRequest(const Request& request, Group* group); static void LogBoundConnectJobToRequest( - const NetLog::Source& connect_job_source, const Request* request); + const NetLog::Source& connect_job_source, const Request& request); // Same as CloseOneIdleSocket() except it won't close an idle socket in // |group|. If |group| is NULL, it is ignored. Returns true if it closed a @@ -588,7 +625,18 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper // to the pool, we can make sure that they are discarded rather than reused. int pool_generation_number_; - std::set<LayeredPool*> higher_layer_pools_; + // Used to add |this| as a higher layer pool on top of lower layer pools. May + // be NULL if no lower layer pools will be added. + HigherLayeredPool* pool_; + + // Pools that create connections through |this|. |this| will try to close + // their idle sockets when it stalls. Must be empty on destruction. + std::set<HigherLayeredPool*> higher_pools_; + + // Pools that this goes through. Typically there's only one, but not always. + // |this| will check if they're stalled when it has a new idle socket. |this| + // will remove itself from all lower layered pools on destruction. + std::set<LowerLayeredPool*> lower_pools_; base::WeakPtrFactory<ClientSocketPoolBaseHelper> weak_factory_; @@ -624,7 +672,7 @@ class ClientSocketPoolBase { ConnectJobFactory() {} virtual ~ConnectJobFactory() {} - virtual ConnectJob* NewConnectJob( + virtual scoped_ptr<ConnectJob> NewConnectJob( const std::string& group_name, const Request& request, ConnectJob::Delegate* delegate) const = 0; @@ -642,6 +690,7 @@ class ClientSocketPoolBase { // |used_idle_socket_timeout| specifies how long to leave a previously used // idle socket open before closing it. ClientSocketPoolBase( + HigherLayeredPool* self, int max_sockets, int max_sockets_per_group, ClientSocketPoolHistograms* histograms, @@ -649,19 +698,23 @@ class ClientSocketPoolBase { base::TimeDelta used_idle_socket_timeout, ConnectJobFactory* connect_job_factory) : histograms_(histograms), - helper_(max_sockets, max_sockets_per_group, + helper_(self, max_sockets, max_sockets_per_group, unused_idle_socket_timeout, used_idle_socket_timeout, new ConnectJobFactoryAdaptor(connect_job_factory)) {} virtual ~ClientSocketPoolBase() {} // These member functions simply forward to ClientSocketPoolBaseHelper. - void AddLayeredPool(LayeredPool* pool) { - helper_.AddLayeredPool(pool); + void AddLowerLayeredPool(LowerLayeredPool* lower_pool) { + helper_.AddLowerLayeredPool(lower_pool); } - void RemoveLayeredPool(LayeredPool* pool) { - helper_.RemoveLayeredPool(pool); + void AddHigherLayeredPool(HigherLayeredPool* higher_pool) { + helper_.AddHigherLayeredPool(higher_pool); + } + + void RemoveHigherLayeredPool(HigherLayeredPool* higher_pool) { + helper_.RemoveHigherLayeredPool(higher_pool); } // RequestSocket bundles up the parameters into a Request and then forwards to @@ -672,12 +725,15 @@ class ClientSocketPoolBase { ClientSocketHandle* handle, const CompletionCallback& callback, const BoundNetLog& net_log) { - Request* request = + scoped_ptr<const Request> request( new Request(handle, callback, priority, internal::ClientSocketPoolBaseHelper::NORMAL, params->ignore_limits(), - params, net_log); - return helper_.RequestSocket(group_name, request); + params, net_log)); + return helper_.RequestSocket( + group_name, + request.template PassAs< + const internal::ClientSocketPoolBaseHelper::Request>()); } // RequestSockets bundles up the parameters into a Request and then forwards @@ -702,9 +758,10 @@ class ClientSocketPoolBase { return helper_.CancelRequest(group_name, handle); } - void ReleaseSocket(const std::string& group_name, StreamSocket* socket, + void ReleaseSocket(const std::string& group_name, + scoped_ptr<StreamSocket> socket, int id) { - return helper_.ReleaseSocket(group_name, socket, id); + return helper_.ReleaseSocket(group_name, socket.Pass(), id); } void FlushWithError(int error) { helper_.FlushWithError(error); } @@ -765,8 +822,8 @@ class ClientSocketPoolBase { bool CloseOneIdleSocket() { return helper_.CloseOneIdleSocket(); } - bool CloseOneIdleConnectionInLayeredPool() { - return helper_.CloseOneIdleConnectionInLayeredPool(); + bool CloseOneIdleConnectionInHigherLayeredPool() { + return helper_.CloseOneIdleConnectionInHigherLayeredPool(); } private: @@ -785,13 +842,13 @@ class ClientSocketPoolBase { : connect_job_factory_(connect_job_factory) {} virtual ~ConnectJobFactoryAdaptor() {} - virtual ConnectJob* NewConnectJob( + virtual scoped_ptr<ConnectJob> NewConnectJob( const std::string& group_name, const internal::ClientSocketPoolBaseHelper::Request& request, - ConnectJob::Delegate* delegate) const { - const Request* casted_request = static_cast<const Request*>(&request); + ConnectJob::Delegate* delegate) const OVERRIDE { + const Request& casted_request = static_cast<const Request&>(request); return connect_job_factory_->NewConnectJob( - group_name, *casted_request, delegate); + group_name, casted_request, delegate); } virtual base::TimeDelta ConnectionTimeout() const { diff --git a/chromium/net/socket/client_socket_pool_base_unittest.cc b/chromium/net/socket/client_socket_pool_base_unittest.cc index 5eeda972cff..bbeca2f3e11 100644 --- a/chromium/net/socket/client_socket_pool_base_unittest.cc +++ b/chromium/net/socket/client_socket_pool_base_unittest.cc @@ -30,7 +30,9 @@ #include "net/socket/client_socket_handle.h" #include "net/socket/client_socket_pool_histograms.h" #include "net/socket/socket_test_util.h" +#include "net/socket/ssl_client_socket.h" #include "net/socket/stream_socket.h" +#include "net/udp/datagram_client_socket.h" #include "testing/gmock/include/gmock/gmock.h" #include "testing/gtest/include/gtest/gtest.h" @@ -189,30 +191,30 @@ class MockClientSocketFactory : public ClientSocketFactory { public: MockClientSocketFactory() : allocation_count_(0) {} - virtual DatagramClientSocket* CreateDatagramClientSocket( + virtual scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket( DatagramSocket::BindType bind_type, const RandIntCallback& rand_int_cb, NetLog* net_log, const NetLog::Source& source) OVERRIDE { NOTREACHED(); - return NULL; + return scoped_ptr<DatagramClientSocket>(); } - virtual StreamSocket* CreateTransportClientSocket( + virtual scoped_ptr<StreamSocket> CreateTransportClientSocket( const AddressList& addresses, NetLog* /* net_log */, const NetLog::Source& /*source*/) OVERRIDE { allocation_count_++; - return NULL; + return scoped_ptr<StreamSocket>(); } - virtual SSLClientSocket* CreateSSLClientSocket( - ClientSocketHandle* transport_socket, + virtual scoped_ptr<SSLClientSocket> CreateSSLClientSocket( + scoped_ptr<ClientSocketHandle> transport_socket, const HostPortPair& host_and_port, const SSLConfig& ssl_config, const SSLClientSocketContext& context) OVERRIDE { NOTIMPLEMENTED(); - return NULL; + return scoped_ptr<SSLClientSocket>(); } virtual void ClearSSLSessionCache() OVERRIDE { @@ -259,7 +261,7 @@ class TestConnectJob : public ConnectJob { ConnectJob::Delegate* delegate, MockClientSocketFactory* client_socket_factory, NetLog* net_log) - : ConnectJob(group_name, timeout_duration, delegate, + : ConnectJob(group_name, timeout_duration, request.priority(), delegate, BoundNetLog::Make(net_log, NetLog::SOURCE_CONNECT_JOB)), job_type_(job_type), client_socket_factory_(client_socket_factory), @@ -294,7 +296,8 @@ class TestConnectJob : public ConnectJob { AddressList ignored; client_socket_factory_->CreateTransportClientSocket( ignored, NULL, net::NetLog::Source()); - set_socket(new MockClientSocket(net_log().net_log())); + SetSocket( + scoped_ptr<StreamSocket>(new MockClientSocket(net_log().net_log()))); switch (job_type_) { case kMockJob: return DoConnect(true /* successful */, false /* sync */, @@ -373,7 +376,7 @@ class TestConnectJob : public ConnectJob { return ERR_IO_PENDING; default: NOTREACHED(); - set_socket(NULL); + SetSocket(scoped_ptr<StreamSocket>()); return ERR_FAILED; } } @@ -386,7 +389,7 @@ class TestConnectJob : public ConnectJob { result = ERR_PROXY_AUTH_REQUESTED; } else { result = ERR_CONNECTION_FAILED; - set_socket(NULL); + SetSocket(scoped_ptr<StreamSocket>()); } if (was_async) @@ -430,7 +433,7 @@ class TestConnectJobFactory // ConnectJobFactory implementation. - virtual ConnectJob* NewConnectJob( + virtual scoped_ptr<ConnectJob> NewConnectJob( const std::string& group_name, const TestClientSocketPoolBase::Request& request, ConnectJob::Delegate* delegate) const OVERRIDE { @@ -440,13 +443,13 @@ class TestConnectJobFactory job_type = job_types_->front(); job_types_->pop_front(); } - return new TestConnectJob(job_type, - group_name, - request, - timeout_duration_, - delegate, - client_socket_factory_, - net_log_); + return scoped_ptr<ConnectJob>(new TestConnectJob(job_type, + group_name, + request, + timeout_duration_, + delegate, + client_socket_factory_, + net_log_)); } virtual base::TimeDelta ConnectionTimeout() const OVERRIDE { @@ -465,6 +468,8 @@ class TestConnectJobFactory class TestClientSocketPool : public ClientSocketPool { public: + typedef TestSocketParams SocketParams; + TestClientSocketPool( int max_sockets, int max_sockets_per_group, @@ -472,7 +477,7 @@ class TestClientSocketPool : public ClientSocketPool { base::TimeDelta unused_idle_socket_timeout, base::TimeDelta used_idle_socket_timeout, TestClientSocketPoolBase::ConnectJobFactory* connect_job_factory) - : base_(max_sockets, max_sockets_per_group, histograms, + : base_(NULL, max_sockets, max_sockets_per_group, histograms, unused_idle_socket_timeout, used_idle_socket_timeout, connect_job_factory) {} @@ -509,9 +514,9 @@ class TestClientSocketPool : public ClientSocketPool { virtual void ReleaseSocket( const std::string& group_name, - StreamSocket* socket, + scoped_ptr<StreamSocket> socket, int id) OVERRIDE { - base_.ReleaseSocket(group_name, socket, id); + base_.ReleaseSocket(group_name, socket.Pass(), id); } virtual void FlushWithError(int error) OVERRIDE { @@ -541,12 +546,13 @@ class TestClientSocketPool : public ClientSocketPool { return base_.GetLoadState(group_name, handle); } - virtual void AddLayeredPool(LayeredPool* pool) OVERRIDE { - base_.AddLayeredPool(pool); + virtual void AddHigherLayeredPool(HigherLayeredPool* higher_pool) OVERRIDE { + base_.AddHigherLayeredPool(higher_pool); } - virtual void RemoveLayeredPool(LayeredPool* pool) OVERRIDE { - base_.RemoveLayeredPool(pool); + virtual void RemoveHigherLayeredPool( + HigherLayeredPool* higher_pool) OVERRIDE { + base_.RemoveHigherLayeredPool(higher_pool); } virtual base::DictionaryValue* GetInfoAsValue( @@ -586,8 +592,8 @@ class TestClientSocketPool : public ClientSocketPool { void EnableConnectBackupJobs() { base_.EnableConnectBackupJobs(); } - bool CloseOneIdleConnectionInLayeredPool() { - return base_.CloseOneIdleConnectionInLayeredPool(); + bool CloseOneIdleConnectionInHigherLayeredPool() { + return base_.CloseOneIdleConnectionInHigherLayeredPool(); } private: @@ -598,8 +604,6 @@ class TestClientSocketPool : public ClientSocketPool { } // namespace -REGISTER_SOCKET_PARAMS_FOR_POOL(TestClientSocketPool, TestSocketParams); - namespace { void MockClientSocketFactory::SignalJobs() { @@ -630,10 +634,10 @@ class TestConnectJobDelegate : public ConnectJob::Delegate { virtual void OnConnectJobComplete(int result, ConnectJob* job) OVERRIDE { result_ = result; - scoped_ptr<StreamSocket> socket(job->ReleaseSocket()); + scoped_ptr<ConnectJob> owned_job(job); + scoped_ptr<StreamSocket> socket = owned_job->PassSocket(); // socket.get() should be NULL iff result != OK - EXPECT_EQ(socket.get() == NULL, result != OK); - delete job; + EXPECT_EQ(socket == NULL, result != OK); have_result_ = true; if (waiting_for_result_) base::MessageLoop::current()->Quit(); @@ -702,9 +706,8 @@ class ClientSocketPoolBaseTest : public testing::Test { const std::string& group_name, RequestPriority priority, const scoped_refptr<TestSocketParams>& params) { - return test_base_.StartRequestUsingPool< - TestClientSocketPool, TestSocketParams>( - pool_.get(), group_name, priority, params); + return test_base_.StartRequestUsingPool( + pool_.get(), group_name, priority, params); } int StartRequest(const std::string& group_name, RequestPriority priority) { @@ -3716,7 +3719,7 @@ TEST_F(ClientSocketPoolBaseTest, PreconnectWithBackupJob) { EXPECT_EQ(1, pool_->NumActiveSocketsInGroup("a")); } -class MockLayeredPool : public LayeredPool { +class MockLayeredPool : public HigherLayeredPool { public: MockLayeredPool(TestClientSocketPool* pool, const std::string& group_name) @@ -3724,11 +3727,11 @@ class MockLayeredPool : public LayeredPool { params_(new TestSocketParams), group_name_(group_name), can_release_connection_(true) { - pool_->AddLayeredPool(this); + pool_->AddHigherLayeredPool(this); } ~MockLayeredPool() { - pool_->RemoveLayeredPool(this); + pool_->RemoveHigherLayeredPool(this); } int RequestSocket(TestClientSocketPool* pool) { @@ -3774,7 +3777,7 @@ TEST_F(ClientSocketPoolBaseTest, FailToCloseIdleSocketsNotHeldByLayeredPool) { EXPECT_EQ(OK, mock_layered_pool.RequestSocket(pool_.get())); EXPECT_CALL(mock_layered_pool, CloseOneIdleConnection()) .WillOnce(Return(false)); - EXPECT_FALSE(pool_->CloseOneIdleConnectionInLayeredPool()); + EXPECT_FALSE(pool_->CloseOneIdleConnectionInHigherLayeredPool()); } TEST_F(ClientSocketPoolBaseTest, ForciblyCloseIdleSocketsHeldByLayeredPool) { @@ -3786,7 +3789,7 @@ TEST_F(ClientSocketPoolBaseTest, ForciblyCloseIdleSocketsHeldByLayeredPool) { EXPECT_CALL(mock_layered_pool, CloseOneIdleConnection()) .WillOnce(Invoke(&mock_layered_pool, &MockLayeredPool::ReleaseOneConnection)); - EXPECT_TRUE(pool_->CloseOneIdleConnectionInLayeredPool()); + EXPECT_TRUE(pool_->CloseOneIdleConnectionInHigherLayeredPool()); } // Tests the basic case of closing an idle socket in a higher layered pool when diff --git a/chromium/net/socket/client_socket_pool_manager.cc b/chromium/net/socket/client_socket_pool_manager.cc index 71496d28646..b37d2d1949c 100644 --- a/chromium/net/socket/client_socket_pool_manager.cc +++ b/chromium/net/socket/client_socket_pool_manager.cc @@ -158,7 +158,6 @@ int InitSocketPoolHelper(const GURL& request_url, bool ignore_limits = (request_load_flags & LOAD_IGNORE_LIMITS) != 0; if (proxy_info.is_direct()) { tcp_params = new TransportSocketParams(origin_host_port, - request_priority, disable_resolver_cache, ignore_limits, resolution_callback); @@ -167,7 +166,6 @@ int InitSocketPoolHelper(const GURL& request_url, proxy_host_port.reset(new HostPortPair(proxy_server.host_port_pair())); scoped_refptr<TransportSocketParams> proxy_tcp_params( new TransportSocketParams(*proxy_host_port, - request_priority, disable_resolver_cache, ignore_limits, resolution_callback)); @@ -182,7 +180,6 @@ int InitSocketPoolHelper(const GURL& request_url, ssl_params = new SSLSocketParams(proxy_tcp_params, NULL, NULL, - ProxyServer::SCHEME_DIRECT, *proxy_host_port.get(), ssl_config_for_proxy, kPrivacyModeDisabled, @@ -214,8 +211,7 @@ int InitSocketPoolHelper(const GURL& request_url, socks_params = new SOCKSSocketParams(proxy_tcp_params, socks_version == '5', - origin_host_port, - request_priority); + origin_host_port); } } @@ -229,7 +225,6 @@ int InitSocketPoolHelper(const GURL& request_url, new SSLSocketParams(tcp_params, socks_params, http_proxy_params, - proxy_info.proxy_server().scheme(), origin_host_port, ssl_config_for_origin, privacy_mode, diff --git a/chromium/net/socket/deterministic_socket_data_unittest.cc b/chromium/net/socket/deterministic_socket_data_unittest.cc index eba01b5e9cc..c51427e25a7 100644 --- a/chromium/net/socket/deterministic_socket_data_unittest.cc +++ b/chromium/net/socket/deterministic_socket_data_unittest.cc @@ -72,7 +72,6 @@ DeterministicSocketDataTest::DeterministicSocketDataTest() connect_data_(SYNCHRONOUS, OK), endpoint_("www.google.com", 443), tcp_params_(new TransportSocketParams(endpoint_, - LOWEST, false, false, OnHostResolutionCallback())), diff --git a/chromium/net/socket/nss_ssl_util.cc b/chromium/net/socket/nss_ssl_util.cc index be33ac5add0..7e3aee430c4 100644 --- a/chromium/net/socket/nss_ssl_util.cc +++ b/chromium/net/socket/nss_ssl_util.cc @@ -58,12 +58,13 @@ class NSSSSLInitSingleton { enabled = false; // Trim the list of cipher suites in order to keep the size of the - // ClientHello down. DSS, ECDH, CAMELLIA, SEED and ECC+3DES cipher - // suites are disabled. + // ClientHello down. DSS, ECDH, CAMELLIA, SEED, ECC+3DES, and + // HMAC-SHA256 cipher suites are disabled. if (info.symCipher == ssl_calg_camellia || info.symCipher == ssl_calg_seed || (info.symCipher == ssl_calg_3des && info.keaType != ssl_kea_rsa) || info.authAlgorithm == ssl_auth_dsa || + info.macAlgorithm == ssl_hmac_sha256 || info.nonStandard || strcmp(info.keaTypeName, "ECDH") == 0) { enabled = false; @@ -232,6 +233,10 @@ int MapNSSError(PRErrorCode err) { case SEC_ERROR_BAD_DER: case SEC_ERROR_EXTRA_INPUT: return ERR_SSL_BAD_PEER_PUBLIC_KEY; + // During renegotiation, the server presented a different certificate than + // was used earlier. + case SSL_ERROR_WRONG_CERTIFICATE: + return ERR_SSL_SERVER_CERT_CHANGED; default: { if (IS_SSL_ERROR(err)) { diff --git a/chromium/net/socket/socket_descriptor.cc b/chromium/net/socket/socket_descriptor.cc new file mode 100644 index 00000000000..5a2e53cab4d --- /dev/null +++ b/chromium/net/socket/socket_descriptor.cc @@ -0,0 +1,48 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/socket_descriptor.h" + +#if defined(OS_POSIX) +#include <sys/types.h> +#include <sys/socket.h> +#endif + +#include "base/basictypes.h" + +#if defined(OS_WIN) +#include "net/base/winsock_init.h" +#endif + +namespace net { + +PlatformSocketFactory* g_socket_factory = NULL; + +PlatformSocketFactory::PlatformSocketFactory() { +} + +PlatformSocketFactory::~PlatformSocketFactory() { +} + +void PlatformSocketFactory::SetInstance(PlatformSocketFactory* factory) { + g_socket_factory = factory; +} + +SocketDescriptor CreateSocketDefault(int family, int type, int protocol) { +#if defined(OS_WIN) + EnsureWinsockInit(); + return ::WSASocket(family, type, protocol, NULL, 0, WSA_FLAG_OVERLAPPED); +#else // OS_WIN + return ::socket(family, type, protocol); +#endif // OS_WIN +} + +SocketDescriptor CreatePlatformSocket(int family, int type, int protocol) { + if (g_socket_factory) + return g_socket_factory->CreateSocket(family, type, protocol); + else + return CreateSocketDefault(family, type, protocol); +} + +} // namespace net diff --git a/chromium/net/socket/socket_descriptor.h b/chromium/net/socket/socket_descriptor.h new file mode 100644 index 00000000000..b2a22234b80 --- /dev/null +++ b/chromium/net/socket/socket_descriptor.h @@ -0,0 +1,49 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_SOCKET_SOCKET_DESCRIPTOR_H_ +#define NET_SOCKET_SOCKET_DESCRIPTOR_H_ + +#include "build/build_config.h" +#include "net/base/net_export.h" + +#if defined(OS_WIN) +#include <winsock2.h> +#endif // OS_WIN + +namespace net { + +#if defined(OS_POSIX) +typedef int SocketDescriptor; +const SocketDescriptor kInvalidSocket = -1; +#elif defined(OS_WIN) +typedef SOCKET SocketDescriptor; +const SocketDescriptor kInvalidSocket = INVALID_SOCKET; +#endif + +// Interface to create native socket. +// Usually such factories are used for testing purposes, which is not true in +// this case. This interface is used to substitute WSASocket/socket to make +// possible execution of some network code in sandbox. +class NET_EXPORT PlatformSocketFactory { + public: + PlatformSocketFactory(); + virtual ~PlatformSocketFactory(); + + // Replace WSASocket/socket with given factory. The factory will be used by + // CreatePlatformSocket. + static void SetInstance(PlatformSocketFactory* factory); + + // Creates socket. See WSASocket/socket documentation of parameters. + virtual SocketDescriptor CreateSocket(int family, int type, int protocol) = 0; +}; + +// Creates socket. See WSASocket/socket documentation of parameters. +SocketDescriptor NET_EXPORT CreatePlatformSocket(int family, + int type, + int protocol); + +} // namespace net + +#endif // NET_SOCKET_SOCKET_DESCRIPTOR_H_ diff --git a/chromium/net/socket/socket_test_util.cc b/chromium/net/socket/socket_test_util.cc index 8b2bdfccba3..78e9e7ce9c4 100644 --- a/chromium/net/socket/socket_test_util.cc +++ b/chromium/net/socket/socket_test_util.cc @@ -657,37 +657,39 @@ void MockClientSocketFactory::ResetNextMockIndexes() { mock_ssl_data_.ResetNextIndex(); } -DatagramClientSocket* MockClientSocketFactory::CreateDatagramClientSocket( +scoped_ptr<DatagramClientSocket> +MockClientSocketFactory::CreateDatagramClientSocket( DatagramSocket::BindType bind_type, const RandIntCallback& rand_int_cb, net::NetLog* net_log, const net::NetLog::Source& source) { SocketDataProvider* data_provider = mock_data_.GetNext(); - MockUDPClientSocket* socket = new MockUDPClientSocket(data_provider, net_log); - data_provider->set_socket(socket); - return socket; + scoped_ptr<MockUDPClientSocket> socket( + new MockUDPClientSocket(data_provider, net_log)); + data_provider->set_socket(socket.get()); + return socket.PassAs<DatagramClientSocket>(); } -StreamSocket* MockClientSocketFactory::CreateTransportClientSocket( +scoped_ptr<StreamSocket> MockClientSocketFactory::CreateTransportClientSocket( const AddressList& addresses, net::NetLog* net_log, const net::NetLog::Source& source) { SocketDataProvider* data_provider = mock_data_.GetNext(); - MockTCPClientSocket* socket = - new MockTCPClientSocket(addresses, net_log, data_provider); - data_provider->set_socket(socket); - return socket; + scoped_ptr<MockTCPClientSocket> socket( + new MockTCPClientSocket(addresses, net_log, data_provider)); + data_provider->set_socket(socket.get()); + return socket.PassAs<StreamSocket>(); } -SSLClientSocket* MockClientSocketFactory::CreateSSLClientSocket( - ClientSocketHandle* transport_socket, +scoped_ptr<SSLClientSocket> MockClientSocketFactory::CreateSSLClientSocket( + scoped_ptr<ClientSocketHandle> transport_socket, const HostPortPair& host_and_port, const SSLConfig& ssl_config, const SSLClientSocketContext& context) { - MockSSLClientSocket* socket = - new MockSSLClientSocket(transport_socket, host_and_port, ssl_config, - mock_ssl_data_.GetNext()); - return socket; + return scoped_ptr<SSLClientSocket>( + new MockSSLClientSocket(transport_socket.Pass(), + host_and_port, ssl_config, + mock_ssl_data_.GetNext())); } void MockClientSocketFactory::ClearSSLSessionCache() { @@ -1278,7 +1280,7 @@ void DeterministicMockTCPClientSocket::OnConnectComplete( // static void MockSSLClientSocket::ConnectCallback( - MockSSLClientSocket *ssl_client_socket, + MockSSLClientSocket* ssl_client_socket, const CompletionCallback& callback, int rv) { if (rv == OK) @@ -1287,7 +1289,7 @@ void MockSSLClientSocket::ConnectCallback( } MockSSLClientSocket::MockSSLClientSocket( - ClientSocketHandle* transport_socket, + scoped_ptr<ClientSocketHandle> transport_socket, const HostPortPair& host_port_pair, const SSLConfig& ssl_config, SSLSocketDataProvider* data) @@ -1295,7 +1297,7 @@ MockSSLClientSocket::MockSSLClientSocket( // Have to use the right BoundNetLog for LoadTimingInfo regression // tests. transport_socket->socket()->NetLog()), - transport_(transport_socket), + transport_(transport_socket.Pass()), data_(data), is_npn_state_set_(false), new_npn_value_(false), @@ -1664,10 +1666,10 @@ void ClientSocketPoolTest::ReleaseAllConnections(KeepAlive keep_alive) { } MockTransportClientSocketPool::MockConnectJob::MockConnectJob( - StreamSocket* socket, + scoped_ptr<StreamSocket> socket, ClientSocketHandle* handle, const CompletionCallback& callback) - : socket_(socket), + : socket_(socket.Pass()), handle_(handle), user_callback_(callback) { } @@ -1698,7 +1700,7 @@ void MockTransportClientSocketPool::MockConnectJob::OnConnect(int rv) { if (!socket_.get()) return; if (rv == OK) { - handle_->set_socket(socket_.release()); + handle_->SetSocket(socket_.Pass()); // Needed for socket pool tests that layer other sockets on top of mock // sockets. @@ -1730,6 +1732,7 @@ MockTransportClientSocketPool::MockTransportClientSocketPool( : TransportClientSocketPool(max_sockets, max_sockets_per_group, histograms, NULL, NULL, NULL), client_socket_factory_(socket_factory), + last_request_priority_(DEFAULT_PRIORITY), release_count_(0), cancel_count_(0) { } @@ -1740,9 +1743,11 @@ int MockTransportClientSocketPool::RequestSocket( const std::string& group_name, const void* socket_params, RequestPriority priority, ClientSocketHandle* handle, const CompletionCallback& callback, const BoundNetLog& net_log) { - StreamSocket* socket = client_socket_factory_->CreateTransportClientSocket( - AddressList(), net_log.net_log(), net::NetLog::Source()); - MockConnectJob* job = new MockConnectJob(socket, handle, callback); + last_request_priority_ = priority; + scoped_ptr<StreamSocket> socket = + client_socket_factory_->CreateTransportClientSocket( + AddressList(), net_log.net_log(), net::NetLog::Source()); + MockConnectJob* job = new MockConnectJob(socket.Pass(), handle, callback); job_list_.push_back(job); handle->set_pool_id(1); return job->Connect(); @@ -1759,11 +1764,12 @@ void MockTransportClientSocketPool::CancelRequest(const std::string& group_name, } } -void MockTransportClientSocketPool::ReleaseSocket(const std::string& group_name, - StreamSocket* socket, int id) { +void MockTransportClientSocketPool::ReleaseSocket( + const std::string& group_name, + scoped_ptr<StreamSocket> socket, + int id) { EXPECT_EQ(1, id); release_count_++; - delete socket; } DeterministicMockClientSocketFactory::DeterministicMockClientSocketFactory() {} @@ -1791,42 +1797,45 @@ MockSSLClientSocket* DeterministicMockClientSocketFactory:: return ssl_client_sockets_[index]; } -DatagramClientSocket* +scoped_ptr<DatagramClientSocket> DeterministicMockClientSocketFactory::CreateDatagramClientSocket( DatagramSocket::BindType bind_type, const RandIntCallback& rand_int_cb, net::NetLog* net_log, const NetLog::Source& source) { DeterministicSocketData* data_provider = mock_data().GetNext(); - DeterministicMockUDPClientSocket* socket = - new DeterministicMockUDPClientSocket(net_log, data_provider); + scoped_ptr<DeterministicMockUDPClientSocket> socket( + new DeterministicMockUDPClientSocket(net_log, data_provider)); data_provider->set_delegate(socket->AsWeakPtr()); - udp_client_sockets().push_back(socket); - return socket; + udp_client_sockets().push_back(socket.get()); + return socket.PassAs<DatagramClientSocket>(); } -StreamSocket* DeterministicMockClientSocketFactory::CreateTransportClientSocket( +scoped_ptr<StreamSocket> +DeterministicMockClientSocketFactory::CreateTransportClientSocket( const AddressList& addresses, net::NetLog* net_log, const net::NetLog::Source& source) { DeterministicSocketData* data_provider = mock_data().GetNext(); - DeterministicMockTCPClientSocket* socket = - new DeterministicMockTCPClientSocket(net_log, data_provider); + scoped_ptr<DeterministicMockTCPClientSocket> socket( + new DeterministicMockTCPClientSocket(net_log, data_provider)); data_provider->set_delegate(socket->AsWeakPtr()); - tcp_client_sockets().push_back(socket); - return socket; + tcp_client_sockets().push_back(socket.get()); + return socket.PassAs<StreamSocket>(); } -SSLClientSocket* DeterministicMockClientSocketFactory::CreateSSLClientSocket( - ClientSocketHandle* transport_socket, +scoped_ptr<SSLClientSocket> +DeterministicMockClientSocketFactory::CreateSSLClientSocket( + scoped_ptr<ClientSocketHandle> transport_socket, const HostPortPair& host_and_port, const SSLConfig& ssl_config, const SSLClientSocketContext& context) { - MockSSLClientSocket* socket = - new MockSSLClientSocket(transport_socket, host_and_port, ssl_config, - mock_ssl_data_.GetNext()); - ssl_client_sockets_.push_back(socket); - return socket; + scoped_ptr<MockSSLClientSocket> socket( + new MockSSLClientSocket(transport_socket.Pass(), + host_and_port, ssl_config, + mock_ssl_data_.GetNext())); + ssl_client_sockets_.push_back(socket.get()); + return socket.PassAs<SSLClientSocket>(); } void DeterministicMockClientSocketFactory::ClearSSLSessionCache() { @@ -1859,8 +1868,9 @@ void MockSOCKSClientSocketPool::CancelRequest( } void MockSOCKSClientSocketPool::ReleaseSocket(const std::string& group_name, - StreamSocket* socket, int id) { - return transport_pool_->ReleaseSocket(group_name, socket, id); + scoped_ptr<StreamSocket> socket, + int id) { + return transport_pool_->ReleaseSocket(group_name, socket.Pass(), id); } const char kSOCKS5GreetRequest[] = { 0x05, 0x01, 0x00 }; diff --git a/chromium/net/socket/socket_test_util.h b/chromium/net/socket/socket_test_util.h index 6afe170299e..e4e56522c92 100644 --- a/chromium/net/socket/socket_test_util.h +++ b/chromium/net/socket/socket_test_util.h @@ -13,6 +13,7 @@ #include "base/basictypes.h" #include "base/callback.h" #include "base/logging.h" +#include "base/memory/ref_counted.h" #include "base/memory/scoped_ptr.h" #include "base/memory/scoped_vector.h" #include "base/memory/weak_ptr.h" @@ -592,17 +593,17 @@ class MockClientSocketFactory : public ClientSocketFactory { } // ClientSocketFactory - virtual DatagramClientSocket* CreateDatagramClientSocket( + virtual scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket( DatagramSocket::BindType bind_type, const RandIntCallback& rand_int_cb, NetLog* net_log, const NetLog::Source& source) OVERRIDE; - virtual StreamSocket* CreateTransportClientSocket( + virtual scoped_ptr<StreamSocket> CreateTransportClientSocket( const AddressList& addresses, NetLog* net_log, const NetLog::Source& source) OVERRIDE; - virtual SSLClientSocket* CreateSSLClientSocket( - ClientSocketHandle* transport_socket, + virtual scoped_ptr<SSLClientSocket> CreateSSLClientSocket( + scoped_ptr<ClientSocketHandle> transport_socket, const HostPortPair& host_and_port, const SSLConfig& ssl_config, const SSLClientSocketContext& context) OVERRIDE; @@ -857,7 +858,7 @@ class DeterministicMockTCPClientSocket class MockSSLClientSocket : public MockClientSocket, public AsyncSocket { public: MockSSLClientSocket( - ClientSocketHandle* transport_socket, + scoped_ptr<ClientSocketHandle> transport_socket, const HostPortPair& host_and_port, const SSLConfig& ssl_config, SSLSocketDataProvider* socket); @@ -1001,11 +1002,12 @@ class ClientSocketPoolTest { ClientSocketPoolTest(); ~ClientSocketPoolTest(); - template <typename PoolType, typename SocketParams> - int StartRequestUsingPool(PoolType* socket_pool, - const std::string& group_name, - RequestPriority priority, - const scoped_refptr<SocketParams>& socket_params) { + template <typename PoolType> + int StartRequestUsingPool( + PoolType* socket_pool, + const std::string& group_name, + RequestPriority priority, + const scoped_refptr<typename PoolType::SocketParams>& socket_params) { DCHECK(socket_pool); TestSocketRequest* request = new TestSocketRequest(&request_order_, &completion_count_); @@ -1045,11 +1047,20 @@ class ClientSocketPoolTest { size_t completion_count_; }; +class MockTransportSocketParams + : public base::RefCounted<MockTransportSocketParams> { + private: + friend class base::RefCounted<MockTransportSocketParams>; + ~MockTransportSocketParams() {} +}; + class MockTransportClientSocketPool : public TransportClientSocketPool { public: + typedef MockTransportSocketParams SocketParams; + class MockConnectJob { public: - MockConnectJob(StreamSocket* socket, ClientSocketHandle* handle, + MockConnectJob(scoped_ptr<StreamSocket> socket, ClientSocketHandle* handle, const CompletionCallback& callback); ~MockConnectJob(); @@ -1074,6 +1085,9 @@ class MockTransportClientSocketPool : public TransportClientSocketPool { virtual ~MockTransportClientSocketPool(); + RequestPriority last_request_priority() const { + return last_request_priority_; + } int release_count() const { return release_count_; } int cancel_count() const { return cancel_count_; } @@ -1088,11 +1102,13 @@ class MockTransportClientSocketPool : public TransportClientSocketPool { virtual void CancelRequest(const std::string& group_name, ClientSocketHandle* handle) OVERRIDE; virtual void ReleaseSocket(const std::string& group_name, - StreamSocket* socket, int id) OVERRIDE; + scoped_ptr<StreamSocket> socket, + int id) OVERRIDE; private: ClientSocketFactory* client_socket_factory_; ScopedVector<MockConnectJob> job_list_; + RequestPriority last_request_priority_; int release_count_; int cancel_count_; @@ -1123,17 +1139,17 @@ class DeterministicMockClientSocketFactory : public ClientSocketFactory { } // ClientSocketFactory - virtual DatagramClientSocket* CreateDatagramClientSocket( + virtual scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket( DatagramSocket::BindType bind_type, const RandIntCallback& rand_int_cb, NetLog* net_log, const NetLog::Source& source) OVERRIDE; - virtual StreamSocket* CreateTransportClientSocket( + virtual scoped_ptr<StreamSocket> CreateTransportClientSocket( const AddressList& addresses, NetLog* net_log, const NetLog::Source& source) OVERRIDE; - virtual SSLClientSocket* CreateSSLClientSocket( - ClientSocketHandle* transport_socket, + virtual scoped_ptr<SSLClientSocket> CreateSSLClientSocket( + scoped_ptr<ClientSocketHandle> transport_socket, const HostPortPair& host_and_port, const SSLConfig& ssl_config, const SSLClientSocketContext& context) OVERRIDE; @@ -1170,7 +1186,8 @@ class MockSOCKSClientSocketPool : public SOCKSClientSocketPool { virtual void CancelRequest(const std::string& group_name, ClientSocketHandle* handle) OVERRIDE; virtual void ReleaseSocket(const std::string& group_name, - StreamSocket* socket, int id) OVERRIDE; + scoped_ptr<StreamSocket> socket, + int id) OVERRIDE; private: TransportClientSocketPool* const transport_pool_; diff --git a/chromium/net/socket/socks5_client_socket.cc b/chromium/net/socket/socks5_client_socket.cc index c9d25bc3dcb..537b584a932 100644 --- a/chromium/net/socket/socks5_client_socket.cc +++ b/chromium/net/socket/socks5_client_socket.cc @@ -28,34 +28,18 @@ COMPILE_ASSERT(sizeof(struct in_addr) == 4, incorrect_system_size_of_IPv4); COMPILE_ASSERT(sizeof(struct in6_addr) == 16, incorrect_system_size_of_IPv6); SOCKS5ClientSocket::SOCKS5ClientSocket( - ClientSocketHandle* transport_socket, + scoped_ptr<ClientSocketHandle> transport_socket, const HostResolver::RequestInfo& req_info) : io_callback_(base::Bind(&SOCKS5ClientSocket::OnIOComplete, base::Unretained(this))), - transport_(transport_socket), + transport_(transport_socket.Pass()), next_state_(STATE_NONE), completed_handshake_(false), bytes_sent_(0), bytes_received_(0), read_header_size(kReadHeaderSize), host_request_info_(req_info), - net_log_(transport_socket->socket()->NetLog()) { -} - -SOCKS5ClientSocket::SOCKS5ClientSocket( - StreamSocket* transport_socket, - const HostResolver::RequestInfo& req_info) - : io_callback_(base::Bind(&SOCKS5ClientSocket::OnIOComplete, - base::Unretained(this))), - transport_(new ClientSocketHandle()), - next_state_(STATE_NONE), - completed_handshake_(false), - bytes_sent_(0), - bytes_received_(0), - read_header_size(kReadHeaderSize), - host_request_info_(req_info), - net_log_(transport_socket->NetLog()) { - transport_->set_socket(transport_socket); + net_log_(transport_->socket()->NetLog()) { } SOCKS5ClientSocket::~SOCKS5ClientSocket() { diff --git a/chromium/net/socket/socks5_client_socket.h b/chromium/net/socket/socks5_client_socket.h index b955e8f42de..45216244f10 100644 --- a/chromium/net/socket/socks5_client_socket.h +++ b/chromium/net/socket/socks5_client_socket.h @@ -28,20 +28,13 @@ class BoundNetLog; // Currently no SOCKSv5 authentication is supported. class NET_EXPORT_PRIVATE SOCKS5ClientSocket : public StreamSocket { public: - // Takes ownership of the |transport_socket|, which should already be - // connected by the time Connect() is called. - // // |req_info| contains the hostname and port to which the socket above will // communicate to via the SOCKS layer. // // Although SOCKS 5 supports 3 different modes of addressing, we will // always pass it a hostname. This means the DNS resolving is done // proxy side. - SOCKS5ClientSocket(ClientSocketHandle* transport_socket, - const HostResolver::RequestInfo& req_info); - - // Deprecated constructor (http://crbug.com/37810) that takes a StreamSocket. - SOCKS5ClientSocket(StreamSocket* transport_socket, + SOCKS5ClientSocket(scoped_ptr<ClientSocketHandle> transport_socket, const HostResolver::RequestInfo& req_info); // On destruction Disconnect() is called. diff --git a/chromium/net/socket/socks5_client_socket_unittest.cc b/chromium/net/socket/socks5_client_socket_unittest.cc index 717d858eef8..78f2ac433c3 100644 --- a/chromium/net/socket/socks5_client_socket_unittest.cc +++ b/chromium/net/socket/socks5_client_socket_unittest.cc @@ -32,13 +32,13 @@ class SOCKS5ClientSocketTest : public PlatformTest { public: SOCKS5ClientSocketTest(); // Create a SOCKSClientSocket on top of a MockSocket. - SOCKS5ClientSocket* BuildMockSocket(MockRead reads[], - size_t reads_count, - MockWrite writes[], - size_t writes_count, - const std::string& hostname, - int port, - NetLog* net_log); + scoped_ptr<SOCKS5ClientSocket> BuildMockSocket(MockRead reads[], + size_t reads_count, + MockWrite writes[], + size_t writes_count, + const std::string& hostname, + int port, + NetLog* net_log); virtual void SetUp(); @@ -47,6 +47,8 @@ class SOCKS5ClientSocketTest : public PlatformTest { CapturingNetLog net_log_; scoped_ptr<SOCKS5ClientSocket> user_sock_; AddressList address_list_; + // Filled in by BuildMockSocket() and owned by its return value + // (which |user_sock| is set to). StreamSocket* tcp_sock_; TestCompletionCallback callback_; scoped_ptr<MockHostResolver> host_resolver_; @@ -68,14 +70,18 @@ void SOCKS5ClientSocketTest::SetUp() { // Resolve the "localhost" AddressList used by the TCP connection to connect. HostResolver::RequestInfo info(HostPortPair("www.socks-proxy.com", 1080)); TestCompletionCallback callback; - int rv = host_resolver_->Resolve(info, &address_list_, callback.callback(), - NULL, BoundNetLog()); + int rv = host_resolver_->Resolve(info, + DEFAULT_PRIORITY, + &address_list_, + callback.callback(), + NULL, + BoundNetLog()); ASSERT_EQ(ERR_IO_PENDING, rv); rv = callback.WaitForResult(); ASSERT_EQ(OK, rv); } -SOCKS5ClientSocket* SOCKS5ClientSocketTest::BuildMockSocket( +scoped_ptr<SOCKS5ClientSocket> SOCKS5ClientSocketTest::BuildMockSocket( MockRead reads[], size_t reads_count, MockWrite writes[], @@ -94,8 +100,13 @@ SOCKS5ClientSocket* SOCKS5ClientSocketTest::BuildMockSocket( EXPECT_EQ(OK, rv); EXPECT_TRUE(tcp_sock_->IsConnected()); - return new SOCKS5ClientSocket(tcp_sock_, - HostResolver::RequestInfo(HostPortPair(hostname, port))); + scoped_ptr<ClientSocketHandle> connection(new ClientSocketHandle); + // |connection| takes ownership of |tcp_sock_|, but keep a + // non-owning pointer to it. + connection->SetSocket(scoped_ptr<StreamSocket>(tcp_sock_)); + return scoped_ptr<SOCKS5ClientSocket>(new SOCKS5ClientSocket( + connection.Pass(), + HostResolver::RequestInfo(HostPortPair(hostname, port)))); } // Tests a complete SOCKS5 handshake and the disconnection. @@ -123,9 +134,9 @@ TEST_F(SOCKS5ClientSocketTest, CompleteHandshake) { MockRead(ASYNC, kSOCKS5OkResponse, kSOCKS5OkResponseLength), MockRead(ASYNC, payload_read.data(), payload_read.size()) }; - user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads), - data_writes, arraysize(data_writes), - "localhost", 80, &net_log_)); + user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), + data_writes, arraysize(data_writes), + "localhost", 80, &net_log_); // At this state the TCP connection is completed but not the SOCKS handshake. EXPECT_TRUE(tcp_sock_->IsConnected()); @@ -195,9 +206,9 @@ TEST_F(SOCKS5ClientSocketTest, ConnectAndDisconnectTwice) { MockRead(SYNCHRONOUS, kSOCKS5OkResponse, kSOCKS5OkResponseLength) }; - user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads), - data_writes, arraysize(data_writes), - hostname, 80, NULL)); + user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), + data_writes, arraysize(data_writes), + hostname, 80, NULL); int rv = user_sock_->Connect(callback_.callback()); EXPECT_EQ(OK, rv); @@ -217,9 +228,9 @@ TEST_F(SOCKS5ClientSocketTest, LargeHostNameFails) { // Create a SOCKS socket, with mock transport socket. MockWrite data_writes[] = {MockWrite()}; MockRead data_reads[] = {MockRead()}; - user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads), - data_writes, arraysize(data_writes), - large_host_name, 80, NULL)); + user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), + data_writes, arraysize(data_writes), + large_host_name, 80, NULL); // Try to connect -- should fail (without having read/written anything to // the transport socket first) because the hostname is too long. @@ -253,9 +264,9 @@ TEST_F(SOCKS5ClientSocketTest, PartialReadWrites) { MockRead data_reads[] = { MockRead(ASYNC, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength), MockRead(ASYNC, kSOCKS5OkResponse, kSOCKS5OkResponseLength) }; - user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads), - data_writes, arraysize(data_writes), - hostname, 80, &net_log_)); + user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), + data_writes, arraysize(data_writes), + hostname, 80, &net_log_); int rv = user_sock_->Connect(callback_.callback()); EXPECT_EQ(ERR_IO_PENDING, rv); @@ -284,9 +295,9 @@ TEST_F(SOCKS5ClientSocketTest, PartialReadWrites) { MockRead(ASYNC, partial1, arraysize(partial1)), MockRead(ASYNC, partial2, arraysize(partial2)), MockRead(ASYNC, kSOCKS5OkResponse, kSOCKS5OkResponseLength) }; - user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads), - data_writes, arraysize(data_writes), - hostname, 80, &net_log_)); + user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), + data_writes, arraysize(data_writes), + hostname, 80, &net_log_); int rv = user_sock_->Connect(callback_.callback()); EXPECT_EQ(ERR_IO_PENDING, rv); @@ -314,9 +325,9 @@ TEST_F(SOCKS5ClientSocketTest, PartialReadWrites) { MockRead data_reads[] = { MockRead(ASYNC, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength), MockRead(ASYNC, kSOCKS5OkResponse, kSOCKS5OkResponseLength) }; - user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads), - data_writes, arraysize(data_writes), - hostname, 80, &net_log_)); + user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), + data_writes, arraysize(data_writes), + hostname, 80, &net_log_); int rv = user_sock_->Connect(callback_.callback()); EXPECT_EQ(ERR_IO_PENDING, rv); CapturingNetLog::CapturedEntryList net_log_entries; @@ -345,9 +356,9 @@ TEST_F(SOCKS5ClientSocketTest, PartialReadWrites) { kSOCKS5OkResponseLength - kSplitPoint) }; - user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads), - data_writes, arraysize(data_writes), - hostname, 80, &net_log_)); + user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), + data_writes, arraysize(data_writes), + hostname, 80, &net_log_); int rv = user_sock_->Connect(callback_.callback()); EXPECT_EQ(ERR_IO_PENDING, rv); CapturingNetLog::CapturedEntryList net_log_entries; diff --git a/chromium/net/socket/socks_client_socket.cc b/chromium/net/socket/socks_client_socket.cc index c4bbd28c619..67089589cc5 100644 --- a/chromium/net/socket/socks_client_socket.cc +++ b/chromium/net/socket/socks_client_socket.cc @@ -55,32 +55,20 @@ struct SOCKS4ServerResponse { COMPILE_ASSERT(sizeof(SOCKS4ServerResponse) == kReadHeaderSize, socks4_server_response_struct_wrong_size); -SOCKSClientSocket::SOCKSClientSocket(ClientSocketHandle* transport_socket, - const HostResolver::RequestInfo& req_info, - HostResolver* host_resolver) - : transport_(transport_socket), +SOCKSClientSocket::SOCKSClientSocket( + scoped_ptr<ClientSocketHandle> transport_socket, + const HostResolver::RequestInfo& req_info, + RequestPriority priority, + HostResolver* host_resolver) + : transport_(transport_socket.Pass()), next_state_(STATE_NONE), completed_handshake_(false), bytes_sent_(0), bytes_received_(0), host_resolver_(host_resolver), host_request_info_(req_info), - net_log_(transport_socket->socket()->NetLog()) { -} - -SOCKSClientSocket::SOCKSClientSocket(StreamSocket* transport_socket, - const HostResolver::RequestInfo& req_info, - HostResolver* host_resolver) - : transport_(new ClientSocketHandle()), - next_state_(STATE_NONE), - completed_handshake_(false), - bytes_sent_(0), - bytes_received_(0), - host_resolver_(host_resolver), - host_request_info_(req_info), - net_log_(transport_socket->NetLog()) { - transport_->set_socket(transport_socket); -} + priority_(priority), + net_log_(transport_->socket()->NetLog()) {} SOCKSClientSocket::~SOCKSClientSocket() { Disconnect(); @@ -283,7 +271,9 @@ int SOCKSClientSocket::DoResolveHost() { // addresses for the target host. host_request_info_.set_address_family(ADDRESS_FAMILY_IPV4); return host_resolver_.Resolve( - host_request_info_, &addresses_, + host_request_info_, + priority_, + &addresses_, base::Bind(&SOCKSClientSocket::OnIOComplete, base::Unretained(this)), net_log_); } diff --git a/chromium/net/socket/socks_client_socket.h b/chromium/net/socket/socks_client_socket.h index 3d4f9fc2771..d4f058a62b1 100644 --- a/chromium/net/socket/socks_client_socket.h +++ b/chromium/net/socket/socks_client_socket.h @@ -27,18 +27,11 @@ class BoundNetLog; // The SOCKS client socket implementation class NET_EXPORT_PRIVATE SOCKSClientSocket : public StreamSocket { public: - // Takes ownership of the |transport_socket|, which should already be - // connected by the time Connect() is called. - // // |req_info| contains the hostname and port to which the socket above will // communicate to via the socks layer. For testing the referrer is optional. - SOCKSClientSocket(ClientSocketHandle* transport_socket, - const HostResolver::RequestInfo& req_info, - HostResolver* host_resolver); - - // Deprecated constructor (http://crbug.com/37810) that takes a StreamSocket. - SOCKSClientSocket(StreamSocket* transport_socket, + SOCKSClientSocket(scoped_ptr<ClientSocketHandle> transport_socket, const HostResolver::RequestInfo& req_info, + RequestPriority priority, HostResolver* host_resolver); // On destruction Disconnect() is called. @@ -131,6 +124,7 @@ class NET_EXPORT_PRIVATE SOCKSClientSocket : public StreamSocket { SingleRequestHostResolver host_resolver_; AddressList addresses_; HostResolver::RequestInfo host_request_info_; + RequestPriority priority_; BoundNetLog net_log_; diff --git a/chromium/net/socket/socks_client_socket_pool.cc b/chromium/net/socket/socks_client_socket_pool.cc index d740e5b9a0e..e11b7a48db5 100644 --- a/chromium/net/socket/socks_client_socket_pool.cc +++ b/chromium/net/socket/socks_client_socket_pool.cc @@ -21,8 +21,7 @@ namespace net { SOCKSSocketParams::SOCKSSocketParams( const scoped_refptr<TransportSocketParams>& proxy_server, bool socks_v5, - const HostPortPair& host_port_pair, - RequestPriority priority) + const HostPortPair& host_port_pair) : transport_params_(proxy_server), destination_(host_port_pair), socks_v5_(socks_v5) { @@ -30,7 +29,6 @@ SOCKSSocketParams::SOCKSSocketParams( ignore_limits_ = transport_params_->ignore_limits(); else ignore_limits_ = false; - destination_.set_priority(priority); } SOCKSSocketParams::~SOCKSSocketParams() {} @@ -41,13 +39,14 @@ static const int kSOCKSConnectJobTimeoutInSeconds = 30; SOCKSConnectJob::SOCKSConnectJob( const std::string& group_name, + RequestPriority priority, const scoped_refptr<SOCKSSocketParams>& socks_params, const base::TimeDelta& timeout_duration, TransportClientSocketPool* transport_pool, HostResolver* host_resolver, Delegate* delegate, NetLog* net_log) - : ConnectJob(group_name, timeout_duration, delegate, + : ConnectJob(group_name, timeout_duration, priority, delegate, BoundNetLog::Make(net_log, NetLog::SOURCE_CONNECT_JOB)), socks_params_(socks_params), transport_pool_(transport_pool), @@ -117,10 +116,12 @@ int SOCKSConnectJob::DoLoop(int result) { int SOCKSConnectJob::DoTransportConnect() { next_state_ = STATE_TRANSPORT_CONNECT_COMPLETE; transport_socket_handle_.reset(new ClientSocketHandle()); - return transport_socket_handle_->Init( - group_name(), socks_params_->transport_params(), - socks_params_->destination().priority(), callback_, transport_pool_, - net_log()); + return transport_socket_handle_->Init(group_name(), + socks_params_->transport_params(), + priority(), + callback_, + transport_pool_, + net_log()); } int SOCKSConnectJob::DoTransportConnectComplete(int result) { @@ -140,11 +141,12 @@ int SOCKSConnectJob::DoSOCKSConnect() { // Add a SOCKS connection on top of the tcp socket. if (socks_params_->is_socks_v5()) { - socket_.reset(new SOCKS5ClientSocket(transport_socket_handle_.release(), + socket_.reset(new SOCKS5ClientSocket(transport_socket_handle_.Pass(), socks_params_->destination())); } else { - socket_.reset(new SOCKSClientSocket(transport_socket_handle_.release(), + socket_.reset(new SOCKSClientSocket(transport_socket_handle_.Pass(), socks_params_->destination(), + priority(), resolver_)); } return socket_->Connect( @@ -157,7 +159,7 @@ int SOCKSConnectJob::DoSOCKSConnectComplete(int result) { return result; } - set_socket(socket_.release()); + SetSocket(socket_.Pass()); return result; } @@ -166,17 +168,19 @@ int SOCKSConnectJob::ConnectInternal() { return DoLoop(OK); } -ConnectJob* SOCKSClientSocketPool::SOCKSConnectJobFactory::NewConnectJob( +scoped_ptr<ConnectJob> +SOCKSClientSocketPool::SOCKSConnectJobFactory::NewConnectJob( const std::string& group_name, const PoolBase::Request& request, ConnectJob::Delegate* delegate) const { - return new SOCKSConnectJob(group_name, - request.params(), - ConnectionTimeout(), - transport_pool_, - host_resolver_, - delegate, - net_log_); + return scoped_ptr<ConnectJob>(new SOCKSConnectJob(group_name, + request.priority(), + request.params(), + ConnectionTimeout(), + transport_pool_, + host_resolver_, + delegate, + net_log_)); } base::TimeDelta @@ -193,7 +197,7 @@ SOCKSClientSocketPool::SOCKSClientSocketPool( TransportClientSocketPool* transport_pool, NetLog* net_log) : transport_pool_(transport_pool), - base_(max_sockets, max_sockets_per_group, histograms, + base_(this, max_sockets, max_sockets_per_group, histograms, ClientSocketPool::unused_idle_socket_timeout(), ClientSocketPool::used_idle_socket_timeout(), new SOCKSConnectJobFactory(transport_pool, @@ -201,13 +205,10 @@ SOCKSClientSocketPool::SOCKSClientSocketPool( net_log)) { // We should always have a |transport_pool_| except in unit tests. if (transport_pool_) - transport_pool_->AddLayeredPool(this); + base_.AddLowerLayeredPool(transport_pool_); } SOCKSClientSocketPool::~SOCKSClientSocketPool() { - // We should always have a |transport_pool_| except in unit tests. - if (transport_pool_) - transport_pool_->RemoveLayeredPool(this); } int SOCKSClientSocketPool::RequestSocket( @@ -238,18 +239,15 @@ void SOCKSClientSocketPool::CancelRequest(const std::string& group_name, } void SOCKSClientSocketPool::ReleaseSocket(const std::string& group_name, - StreamSocket* socket, int id) { - base_.ReleaseSocket(group_name, socket, id); + scoped_ptr<StreamSocket> socket, + int id) { + base_.ReleaseSocket(group_name, socket.Pass(), id); } void SOCKSClientSocketPool::FlushWithError(int error) { base_.FlushWithError(error); } -bool SOCKSClientSocketPool::IsStalled() const { - return base_.IsStalled() || transport_pool_->IsStalled(); -} - void SOCKSClientSocketPool::CloseIdleSockets() { base_.CloseIdleSockets(); } @@ -268,14 +266,6 @@ LoadState SOCKSClientSocketPool::GetLoadState( return base_.GetLoadState(group_name, handle); } -void SOCKSClientSocketPool::AddLayeredPool(LayeredPool* layered_pool) { - base_.AddLayeredPool(layered_pool); -} - -void SOCKSClientSocketPool::RemoveLayeredPool(LayeredPool* layered_pool) { - base_.RemoveLayeredPool(layered_pool); -} - base::DictionaryValue* SOCKSClientSocketPool::GetInfoAsValue( const std::string& name, const std::string& type, @@ -299,10 +289,24 @@ ClientSocketPoolHistograms* SOCKSClientSocketPool::histograms() const { return base_.histograms(); }; +bool SOCKSClientSocketPool::IsStalled() const { + return base_.IsStalled(); +} + +void SOCKSClientSocketPool::AddHigherLayeredPool( + HigherLayeredPool* higher_pool) { + base_.AddHigherLayeredPool(higher_pool); +} + +void SOCKSClientSocketPool::RemoveHigherLayeredPool( + HigherLayeredPool* higher_pool) { + base_.RemoveHigherLayeredPool(higher_pool); +} + bool SOCKSClientSocketPool::CloseOneIdleConnection() { if (base_.CloseOneIdleSocket()) return true; - return base_.CloseOneIdleConnectionInLayeredPool(); + return base_.CloseOneIdleConnectionInHigherLayeredPool(); } } // namespace net diff --git a/chromium/net/socket/socks_client_socket_pool.h b/chromium/net/socket/socks_client_socket_pool.h index 86609a1a5a0..c6d5c8d0883 100644 --- a/chromium/net/socket/socks_client_socket_pool.h +++ b/chromium/net/socket/socks_client_socket_pool.h @@ -28,8 +28,7 @@ class NET_EXPORT_PRIVATE SOCKSSocketParams : public base::RefCounted<SOCKSSocketParams> { public: SOCKSSocketParams(const scoped_refptr<TransportSocketParams>& proxy_server, - bool socks_v5, const HostPortPair& host_port_pair, - RequestPriority priority); + bool socks_v5, const HostPortPair& host_port_pair); const scoped_refptr<TransportSocketParams>& transport_params() const { return transport_params_; @@ -57,6 +56,7 @@ class NET_EXPORT_PRIVATE SOCKSSocketParams class SOCKSConnectJob : public ConnectJob { public: SOCKSConnectJob(const std::string& group_name, + RequestPriority priority, const scoped_refptr<SOCKSSocketParams>& params, const base::TimeDelta& timeout_duration, TransportClientSocketPool* transport_pool, @@ -105,8 +105,10 @@ class SOCKSConnectJob : public ConnectJob { }; class NET_EXPORT_PRIVATE SOCKSClientSocketPool - : public ClientSocketPool, public LayeredPool { + : public ClientSocketPool, public HigherLayeredPool { public: + typedef SOCKSSocketParams SocketParams; + SOCKSClientSocketPool( int max_sockets, int max_sockets_per_group, @@ -134,13 +136,11 @@ class NET_EXPORT_PRIVATE SOCKSClientSocketPool ClientSocketHandle* handle) OVERRIDE; virtual void ReleaseSocket(const std::string& group_name, - StreamSocket* socket, + scoped_ptr<StreamSocket> socket, int id) OVERRIDE; virtual void FlushWithError(int error) OVERRIDE; - virtual bool IsStalled() const OVERRIDE; - virtual void CloseIdleSockets() OVERRIDE; virtual int IdleSocketCount() const OVERRIDE; @@ -152,10 +152,6 @@ class NET_EXPORT_PRIVATE SOCKSClientSocketPool const std::string& group_name, const ClientSocketHandle* handle) const OVERRIDE; - virtual void AddLayeredPool(LayeredPool* layered_pool) OVERRIDE; - - virtual void RemoveLayeredPool(LayeredPool* layered_pool) OVERRIDE; - virtual base::DictionaryValue* GetInfoAsValue( const std::string& name, const std::string& type, @@ -165,7 +161,14 @@ class NET_EXPORT_PRIVATE SOCKSClientSocketPool virtual ClientSocketPoolHistograms* histograms() const OVERRIDE; - // LayeredPool implementation. + // LowerLayeredPool implementation. + virtual bool IsStalled() const OVERRIDE; + + virtual void AddHigherLayeredPool(HigherLayeredPool* higher_pool) OVERRIDE; + + virtual void RemoveHigherLayeredPool(HigherLayeredPool* higher_pool) OVERRIDE; + + // HigherLayeredPool implementation. virtual bool CloseOneIdleConnection() OVERRIDE; private: @@ -183,7 +186,7 @@ class NET_EXPORT_PRIVATE SOCKSClientSocketPool virtual ~SOCKSConnectJobFactory() {} // ClientSocketPoolBase::ConnectJobFactory methods. - virtual ConnectJob* NewConnectJob( + virtual scoped_ptr<ConnectJob> NewConnectJob( const std::string& group_name, const PoolBase::Request& request, ConnectJob::Delegate* delegate) const OVERRIDE; @@ -204,8 +207,6 @@ class NET_EXPORT_PRIVATE SOCKSClientSocketPool DISALLOW_COPY_AND_ASSIGN(SOCKSClientSocketPool); }; -REGISTER_SOCKET_PARAMS_FOR_POOL(SOCKSClientSocketPool, SOCKSSocketParams); - } // namespace net #endif // NET_SOCKET_SOCKS_CLIENT_SOCKET_POOL_H_ diff --git a/chromium/net/socket/socks_client_socket_pool_unittest.cc b/chromium/net/socket/socks_client_socket_pool_unittest.cc index 77440d36a19..4463e171f84 100644 --- a/chromium/net/socket/socks_client_socket_pool_unittest.cc +++ b/chromium/net/socket/socks_client_socket_pool_unittest.cc @@ -41,6 +41,25 @@ void TestLoadTimingInfo(const ClientSocketHandle& handle) { ExpectLoadTimingHasOnlyConnectionTimes(load_timing_info); } + +scoped_refptr<TransportSocketParams> CreateProxyHostParams() { + return new TransportSocketParams( + HostPortPair("proxy", 80), false, false, + OnHostResolutionCallback()); +} + +scoped_refptr<SOCKSSocketParams> CreateSOCKSv4Params() { + return new SOCKSSocketParams( + CreateProxyHostParams(), false /* socks_v5 */, + HostPortPair("host", 80)); +} + +scoped_refptr<SOCKSSocketParams> CreateSOCKSv5Params() { + return new SOCKSSocketParams( + CreateProxyHostParams(), true /* socks_v5 */, + HostPortPair("host", 80)); +} + class SOCKSClientSocketPoolTest : public testing::Test { protected: class SOCKS5MockData { @@ -71,30 +90,24 @@ class SOCKSClientSocketPoolTest : public testing::Test { }; SOCKSClientSocketPoolTest() - : ignored_transport_socket_params_(new TransportSocketParams( - HostPortPair("proxy", 80), MEDIUM, false, false, - OnHostResolutionCallback())), - transport_histograms_("MockTCP"), + : transport_histograms_("MockTCP"), transport_socket_pool_( kMaxSockets, kMaxSocketsPerGroup, &transport_histograms_, &transport_client_socket_factory_), - ignored_socket_params_(new SOCKSSocketParams( - ignored_transport_socket_params_, true, HostPortPair("host", 80), - MEDIUM)), socks_histograms_("SOCKSUnitTest"), pool_(kMaxSockets, kMaxSocketsPerGroup, &socks_histograms_, - NULL, + &host_resolver_, &transport_socket_pool_, NULL) { } virtual ~SOCKSClientSocketPoolTest() {} - int StartRequest(const std::string& group_name, RequestPriority priority) { + int StartRequestV5(const std::string& group_name, RequestPriority priority) { return test_base_.StartRequestUsingPool( - &pool_, group_name, priority, ignored_socket_params_); + &pool_, group_name, priority, CreateSOCKSv5Params()); } int GetOrderOfRequest(size_t index) const { @@ -103,13 +116,12 @@ class SOCKSClientSocketPoolTest : public testing::Test { ScopedVector<TestSocketRequest>* requests() { return test_base_.requests(); } - scoped_refptr<TransportSocketParams> ignored_transport_socket_params_; ClientSocketPoolHistograms transport_histograms_; MockClientSocketFactory transport_client_socket_factory_; MockTransportClientSocketPool transport_socket_pool_; - scoped_refptr<SOCKSSocketParams> ignored_socket_params_; ClientSocketPoolHistograms socks_histograms_; + MockHostResolver host_resolver_; SOCKSClientSocketPool pool_; ClientSocketPoolTest test_base_; }; @@ -120,7 +132,7 @@ TEST_F(SOCKSClientSocketPoolTest, Simple) { transport_client_socket_factory_.AddSocketDataProvider(data.data_provider()); ClientSocketHandle handle; - int rv = handle.Init("a", ignored_socket_params_, LOW, CompletionCallback(), + int rv = handle.Init("a", CreateSOCKSv5Params(), LOW, CompletionCallback(), &pool_, BoundNetLog()); EXPECT_EQ(OK, rv); EXPECT_TRUE(handle.is_initialized()); @@ -128,13 +140,52 @@ TEST_F(SOCKSClientSocketPoolTest, Simple) { TestLoadTimingInfo(handle); } +// Make sure that SOCKSConnectJob passes on its priority to its +// socket request on Init. +TEST_F(SOCKSClientSocketPoolTest, SetSocketRequestPriorityOnInit) { + for (int i = MINIMUM_PRIORITY; i < NUM_PRIORITIES; ++i) { + RequestPriority priority = static_cast<RequestPriority>(i); + SOCKS5MockData data(SYNCHRONOUS); + data.data_provider()->set_connect_data(MockConnect(SYNCHRONOUS, OK)); + transport_client_socket_factory_.AddSocketDataProvider( + data.data_provider()); + + ClientSocketHandle handle; + EXPECT_EQ(OK, + handle.Init("a", CreateSOCKSv5Params(), priority, + CompletionCallback(), &pool_, BoundNetLog())); + EXPECT_EQ(priority, transport_socket_pool_.last_request_priority()); + handle.socket()->Disconnect(); + } +} + +// Make sure that SOCKSConnectJob passes on its priority to its +// HostResolver request (for non-SOCKS5) on Init. +TEST_F(SOCKSClientSocketPoolTest, SetResolvePriorityOnInit) { + for (int i = MINIMUM_PRIORITY; i < NUM_PRIORITIES; ++i) { + RequestPriority priority = static_cast<RequestPriority>(i); + SOCKS5MockData data(SYNCHRONOUS); + data.data_provider()->set_connect_data(MockConnect(SYNCHRONOUS, OK)); + transport_client_socket_factory_.AddSocketDataProvider( + data.data_provider()); + + ClientSocketHandle handle; + EXPECT_EQ(ERR_IO_PENDING, + handle.Init("a", CreateSOCKSv4Params(), priority, + CompletionCallback(), &pool_, BoundNetLog())); + EXPECT_EQ(priority, transport_socket_pool_.last_request_priority()); + EXPECT_EQ(priority, host_resolver_.last_request_priority()); + EXPECT_TRUE(handle.socket() == NULL); + } +} + TEST_F(SOCKSClientSocketPoolTest, Async) { SOCKS5MockData data(ASYNC); transport_client_socket_factory_.AddSocketDataProvider(data.data_provider()); TestCompletionCallback callback; ClientSocketHandle handle; - int rv = handle.Init("a", ignored_socket_params_, LOW, callback.callback(), + int rv = handle.Init("a", CreateSOCKSv5Params(), LOW, callback.callback(), &pool_, BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle.is_initialized()); @@ -153,7 +204,7 @@ TEST_F(SOCKSClientSocketPoolTest, TransportConnectError) { transport_client_socket_factory_.AddSocketDataProvider(&socket_data); ClientSocketHandle handle; - int rv = handle.Init("a", ignored_socket_params_, LOW, CompletionCallback(), + int rv = handle.Init("a", CreateSOCKSv5Params(), LOW, CompletionCallback(), &pool_, BoundNetLog()); EXPECT_EQ(ERR_PROXY_CONNECTION_FAILED, rv); EXPECT_FALSE(handle.is_initialized()); @@ -167,7 +218,7 @@ TEST_F(SOCKSClientSocketPoolTest, AsyncTransportConnectError) { TestCompletionCallback callback; ClientSocketHandle handle; - int rv = handle.Init("a", ignored_socket_params_, LOW, callback.callback(), + int rv = handle.Init("a", CreateSOCKSv5Params(), LOW, callback.callback(), &pool_, BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle.is_initialized()); @@ -189,7 +240,7 @@ TEST_F(SOCKSClientSocketPoolTest, SOCKSConnectError) { ClientSocketHandle handle; EXPECT_EQ(0, transport_socket_pool_.release_count()); - int rv = handle.Init("a", ignored_socket_params_, LOW, CompletionCallback(), + int rv = handle.Init("a", CreateSOCKSv5Params(), LOW, CompletionCallback(), &pool_, BoundNetLog()); EXPECT_EQ(ERR_SOCKS_CONNECTION_FAILED, rv); EXPECT_FALSE(handle.is_initialized()); @@ -209,7 +260,7 @@ TEST_F(SOCKSClientSocketPoolTest, AsyncSOCKSConnectError) { TestCompletionCallback callback; ClientSocketHandle handle; EXPECT_EQ(0, transport_socket_pool_.release_count()); - int rv = handle.Init("a", ignored_socket_params_, LOW, callback.callback(), + int rv = handle.Init("a", CreateSOCKSv5Params(), LOW, callback.callback(), &pool_, BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle.is_initialized()); @@ -230,10 +281,10 @@ TEST_F(SOCKSClientSocketPoolTest, CancelDuringTransportConnect) { transport_client_socket_factory_.AddSocketDataProvider(data2.data_provider()); EXPECT_EQ(0, transport_socket_pool_.cancel_count()); - int rv = StartRequest("a", LOW); + int rv = StartRequestV5("a", LOW); EXPECT_EQ(ERR_IO_PENDING, rv); - rv = StartRequest("a", LOW); + rv = StartRequestV5("a", LOW); EXPECT_EQ(ERR_IO_PENDING, rv); pool_.CancelRequest("a", (*requests())[0]->handle()); @@ -265,10 +316,10 @@ TEST_F(SOCKSClientSocketPoolTest, CancelDuringSOCKSConnect) { EXPECT_EQ(0, transport_socket_pool_.cancel_count()); EXPECT_EQ(0, transport_socket_pool_.release_count()); - int rv = StartRequest("a", LOW); + int rv = StartRequestV5("a", LOW); EXPECT_EQ(ERR_IO_PENDING, rv); - rv = StartRequest("a", LOW); + rv = StartRequestV5("a", LOW); EXPECT_EQ(ERR_IO_PENDING, rv); pool_.CancelRequest("a", (*requests())[0]->handle()); diff --git a/chromium/net/socket/socks_client_socket_unittest.cc b/chromium/net/socket/socks_client_socket_unittest.cc index 7a8faf69856..f361244feff 100644 --- a/chromium/net/socket/socks_client_socket_unittest.cc +++ b/chromium/net/socket/socks_client_socket_unittest.cc @@ -4,6 +4,7 @@ #include "net/socket/socks_client_socket.h" +#include "base/memory/scoped_ptr.h" #include "net/base/address_list.h" #include "net/base/net_log.h" #include "net/base/net_log_unittest.h" @@ -27,16 +28,19 @@ class SOCKSClientSocketTest : public PlatformTest { public: SOCKSClientSocketTest(); // Create a SOCKSClientSocket on top of a MockSocket. - SOCKSClientSocket* BuildMockSocket(MockRead reads[], size_t reads_count, - MockWrite writes[], size_t writes_count, - HostResolver* host_resolver, - const std::string& hostname, int port, - NetLog* net_log); + scoped_ptr<SOCKSClientSocket> BuildMockSocket( + MockRead reads[], size_t reads_count, + MockWrite writes[], size_t writes_count, + HostResolver* host_resolver, + const std::string& hostname, int port, + NetLog* net_log); virtual void SetUp(); protected: scoped_ptr<SOCKSClientSocket> user_sock_; AddressList address_list_; + // Filled in by BuildMockSocket() and owned by its return value + // (which |user_sock| is set to). StreamSocket* tcp_sock_; TestCompletionCallback callback_; scoped_ptr<MockHostResolver> host_resolver_; @@ -52,7 +56,7 @@ void SOCKSClientSocketTest::SetUp() { PlatformTest::SetUp(); } -SOCKSClientSocket* SOCKSClientSocketTest::BuildMockSocket( +scoped_ptr<SOCKSClientSocket> SOCKSClientSocketTest::BuildMockSocket( MockRead reads[], size_t reads_count, MockWrite writes[], @@ -73,9 +77,15 @@ SOCKSClientSocket* SOCKSClientSocketTest::BuildMockSocket( EXPECT_EQ(OK, rv); EXPECT_TRUE(tcp_sock_->IsConnected()); - return new SOCKSClientSocket(tcp_sock_, + scoped_ptr<ClientSocketHandle> connection(new ClientSocketHandle); + // |connection| takes ownership of |tcp_sock_|, but keep a + // non-owning pointer to it. + connection->SetSocket(scoped_ptr<StreamSocket>(tcp_sock_)); + return scoped_ptr<SOCKSClientSocket>(new SOCKSClientSocket( + connection.Pass(), HostResolver::RequestInfo(HostPortPair(hostname, port)), - host_resolver); + DEFAULT_PRIORITY, + host_resolver)); } // Implementation of HostResolver that never completes its resolve request. @@ -86,6 +96,7 @@ class HangingHostResolverWithCancel : public HostResolver { HangingHostResolverWithCancel() : outstanding_request_(NULL) {} virtual int Resolve(const RequestInfo& info, + RequestPriority priority, AddressList* addresses, const CompletionCallback& callback, RequestHandle* out_req, @@ -134,11 +145,11 @@ TEST_F(SOCKSClientSocketTest, CompleteHandshake) { MockRead(ASYNC, payload_read.data(), payload_read.size()) }; CapturingNetLog log; - user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads), - data_writes, arraysize(data_writes), - host_resolver_.get(), - "localhost", 80, - &log)); + user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), + data_writes, arraysize(data_writes), + host_resolver_.get(), + "localhost", 80, + &log); // At this state the TCP connection is completed but not the SOCKS handshake. EXPECT_TRUE(tcp_sock_->IsConnected()); @@ -210,11 +221,11 @@ TEST_F(SOCKSClientSocketTest, HandshakeFailures) { arraysize(tests[i].fail_reply)) }; CapturingNetLog log; - user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads), - data_writes, arraysize(data_writes), - host_resolver_.get(), - "localhost", 80, - &log)); + user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), + data_writes, arraysize(data_writes), + host_resolver_.get(), + "localhost", 80, + &log); int rv = user_sock_->Connect(callback_.callback()); EXPECT_EQ(ERR_IO_PENDING, rv); @@ -247,11 +258,11 @@ TEST_F(SOCKSClientSocketTest, PartialServerReads) { MockRead(ASYNC, kSOCKSPartialReply2, arraysize(kSOCKSPartialReply2)) }; CapturingNetLog log; - user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads), - data_writes, arraysize(data_writes), - host_resolver_.get(), - "localhost", 80, - &log)); + user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), + data_writes, arraysize(data_writes), + host_resolver_.get(), + "localhost", 80, + &log); int rv = user_sock_->Connect(callback_.callback()); EXPECT_EQ(ERR_IO_PENDING, rv); @@ -285,11 +296,11 @@ TEST_F(SOCKSClientSocketTest, PartialClientWrites) { MockRead(ASYNC, kSOCKSOkReply, arraysize(kSOCKSOkReply)) }; CapturingNetLog log; - user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads), - data_writes, arraysize(data_writes), - host_resolver_.get(), - "localhost", 80, - &log)); + user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), + data_writes, arraysize(data_writes), + host_resolver_.get(), + "localhost", 80, + &log); int rv = user_sock_->Connect(callback_.callback()); EXPECT_EQ(ERR_IO_PENDING, rv); @@ -317,11 +328,11 @@ TEST_F(SOCKSClientSocketTest, FailedSocketRead) { MockRead(SYNCHRONOUS, 0) }; CapturingNetLog log; - user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads), - data_writes, arraysize(data_writes), - host_resolver_.get(), - "localhost", 80, - &log)); + user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), + data_writes, arraysize(data_writes), + host_resolver_.get(), + "localhost", 80, + &log); int rv = user_sock_->Connect(callback_.callback()); EXPECT_EQ(ERR_IO_PENDING, rv); @@ -347,11 +358,11 @@ TEST_F(SOCKSClientSocketTest, FailedDNS) { CapturingNetLog log; - user_sock_.reset(BuildMockSocket(NULL, 0, - NULL, 0, - host_resolver_.get(), - hostname, 80, - &log)); + user_sock_ = BuildMockSocket(NULL, 0, + NULL, 0, + host_resolver_.get(), + hostname, 80, + &log); int rv = user_sock_->Connect(callback_.callback()); EXPECT_EQ(ERR_IO_PENDING, rv); @@ -378,11 +389,11 @@ TEST_F(SOCKSClientSocketTest, DisconnectWhileHostResolveInProgress) { MockWrite data_writes[] = { MockWrite(SYNCHRONOUS, "", 0) }; MockRead data_reads[] = { MockRead(SYNCHRONOUS, "", 0) }; - user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads), - data_writes, arraysize(data_writes), - hanging_resolver.get(), - "foo", 80, - NULL)); + user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), + data_writes, arraysize(data_writes), + hanging_resolver.get(), + "foo", 80, + NULL); // Start connecting (will get stuck waiting for the host to resolve). int rv = user_sock_->Connect(callback_.callback()); diff --git a/chromium/net/socket/ssl_client_socket_nss.cc b/chromium/net/socket/ssl_client_socket_nss.cc index f374dedcf80..0de7cfb9060 100644 --- a/chromium/net/socket/ssl_client_socket_nss.cc +++ b/chromium/net/socket/ssl_client_socket_nss.cc @@ -380,20 +380,16 @@ void PeerCertificateChain::Reset(PRFileDesc* nss_fd) { if (nss_fd == NULL) return; - unsigned int num_certs = 0; - SECStatus rv = SSL_PeerCertificateChain(nss_fd, NULL, &num_certs, 0); - DCHECK_EQ(SECSuccess, rv); - + CERTCertList* list = SSL_PeerCertificateChain(nss_fd); // The handshake on |nss_fd| may not have completed. - if (num_certs == 0) + if (list == NULL) return; - certs_.resize(num_certs); - const unsigned int expected_num_certs = num_certs; - rv = SSL_PeerCertificateChain(nss_fd, vector_as_array(&certs_), - &num_certs, expected_num_certs); - DCHECK_EQ(SECSuccess, rv); - DCHECK_EQ(expected_num_certs, num_certs); + for (CERTCertListNode* node = CERT_LIST_HEAD(list); + !CERT_LIST_END(node, list); node = CERT_LIST_NEXT(node)) { + certs_.push_back(CERT_DupCertificate(node->cert)); + } + CERT_DestroyCertList(list); } std::vector<base::StringPiece> @@ -1291,6 +1287,19 @@ SECStatus SSLClientSocketNSS::Core::OwnAuthCertHandler( // Start with it. SSL_OptionSet(socket, SSL_ENABLE_FALSE_START, PR_FALSE); } + } else { + // Disallow the server certificate to change in a renegotiation. + CERTCertificate* old_cert = core->nss_handshake_state_.server_cert_chain[0]; + ScopedCERTCertificate new_cert(SSL_PeerCertificate(socket)); + if (new_cert->derCert.len != old_cert->derCert.len || + memcmp(new_cert->derCert.data, old_cert->derCert.data, + new_cert->derCert.len) != 0) { + // NSS doesn't have an error code that indicates the server certificate + // changed. Borrow SSL_ERROR_WRONG_CERTIFICATE (which NSS isn't using) + // for this purpose. + PORT_SetError(SSL_ERROR_WRONG_CERTIFICATE); + return SECFailure; + } } // Tell NSS to not verify the certificate. @@ -2598,7 +2607,7 @@ int SSLClientSocketNSS::Core::DoGetDomainBoundCert(const std::string& host) { weak_net_log_->BeginEvent(NetLog::TYPE_SSL_GET_DOMAIN_BOUND_CERT); - int rv = server_bound_cert_service_->GetDomainBoundCert( + int rv = server_bound_cert_service_->GetOrCreateDomainBoundCert( host, &domain_bound_private_key_, &domain_bound_cert_, @@ -2751,12 +2760,12 @@ void SSLClientSocketNSS::Core::SetChannelIDProvided() { SSLClientSocketNSS::SSLClientSocketNSS( base::SequencedTaskRunner* nss_task_runner, - ClientSocketHandle* transport_socket, + scoped_ptr<ClientSocketHandle> transport_socket, const HostPortPair& host_and_port, const SSLConfig& ssl_config, const SSLClientSocketContext& context) : nss_task_runner_(nss_task_runner), - transport_(transport_socket), + transport_(transport_socket.Pass()), host_and_port_(host_and_port), ssl_config_(ssl_config), cert_verifier_(context.cert_verifier), @@ -2765,7 +2774,7 @@ SSLClientSocketNSS::SSLClientSocketNSS( completed_handshake_(false), next_handshake_state_(STATE_NONE), nss_fd_(NULL), - net_log_(transport_socket->socket()->NetLog()), + net_log_(transport_->socket()->NetLog()), transport_security_state_(context.transport_security_state), valid_thread_id_(base::kInvalidThreadId) { EnterFunction(""); @@ -3141,7 +3150,8 @@ int SSLClientSocketNSS::InitializeSSLOptions() { net_log_, "SSL_OptionSet", "SSL_ENABLE_SESSION_TICKETS"); } - rv = SSL_OptionSet(nss_fd_, SSL_ENABLE_FALSE_START, PR_FALSE); + rv = SSL_OptionSet(nss_fd_, SSL_ENABLE_FALSE_START, + ssl_config_.false_start_enabled); if (rv != SECSuccess) LogFailedNSSFunction(net_log_, "SSL_OptionSet", "SSL_ENABLE_FALSE_START"); diff --git a/chromium/net/socket/ssl_client_socket_nss.h b/chromium/net/socket/ssl_client_socket_nss.h index fed8ef706b5..b41d28d74a8 100644 --- a/chromium/net/socket/ssl_client_socket_nss.h +++ b/chromium/net/socket/ssl_client_socket_nss.h @@ -59,7 +59,7 @@ class SSLClientSocketNSS : public SSLClientSocket { // behaviour is desired, for performance or compatibility, the current task // runner should be supplied instead. SSLClientSocketNSS(base::SequencedTaskRunner* nss_task_runner, - ClientSocketHandle* transport_socket, + scoped_ptr<ClientSocketHandle> transport_socket, const HostPortPair& host_and_port, const SSLConfig& ssl_config, const SSLClientSocketContext& context); diff --git a/chromium/net/socket/ssl_client_socket_openssl.cc b/chromium/net/socket/ssl_client_socket_openssl.cc index 1431bc61486..416ab87bc4b 100644 --- a/chromium/net/socket/ssl_client_socket_openssl.cc +++ b/chromium/net/socket/ssl_client_socket_openssl.cc @@ -425,7 +425,7 @@ void SSLClientSocket::ClearSessionCache() { } SSLClientSocketOpenSSL::SSLClientSocketOpenSSL( - ClientSocketHandle* transport_socket, + scoped_ptr<ClientSocketHandle> transport_socket, const HostPortPair& host_and_port, const SSLConfig& ssl_config, const SSLClientSocketContext& context) @@ -439,14 +439,14 @@ SSLClientSocketOpenSSL::SSLClientSocketOpenSSL( cert_verifier_(context.cert_verifier), ssl_(NULL), transport_bio_(NULL), - transport_(transport_socket), + transport_(transport_socket.Pass()), host_and_port_(host_and_port), ssl_config_(ssl_config), ssl_session_cache_shard_(context.ssl_session_cache_shard), trying_cached_session_(false), next_handshake_state_(STATE_NONE), npn_status_(kNextProtoUnsupported), - net_log_(transport_socket->socket()->NetLog()) { + net_log_(transport_->socket()->NetLog()) { } SSLClientSocketOpenSSL::~SSLClientSocketOpenSSL() { @@ -532,9 +532,11 @@ bool SSLClientSocketOpenSSL::Init() { STACK_OF(SSL_CIPHER)* ciphers = SSL_get_ciphers(ssl_); DCHECK(ciphers); // See SSLConfig::disabled_cipher_suites for description of the suites - // disabled by default. Note that !SHA384 only removes HMAC-SHA384 cipher - // suites, not GCM cipher suites with SHA384 as the handshake hash. - std::string command("DEFAULT:!NULL:!aNULL:!IDEA:!FZA:!SRP:!SHA384:!aECDH"); + // disabled by default. Note that !SHA256 and !SHA384 only remove HMAC-SHA256 + // and HMAC-SHA384 cipher suites, not GCM cipher suites with SHA256 or SHA384 + // as the handshake hash. + std::string command("DEFAULT:!NULL:!aNULL:!IDEA:!FZA:!SRP:!SHA256:!SHA384:" + "!aECDH:!AESGCM+AES256"); // Walk through all the installed ciphers, seeing if any need to be // appended to the cipher removal |command|. for (int i = 0; i < sk_SSL_CIPHER_num(ciphers); ++i) { diff --git a/chromium/net/socket/ssl_client_socket_openssl.h b/chromium/net/socket/ssl_client_socket_openssl.h index 520f432b8bc..f66d95cc69d 100644 --- a/chromium/net/socket/ssl_client_socket_openssl.h +++ b/chromium/net/socket/ssl_client_socket_openssl.h @@ -41,7 +41,7 @@ class SSLClientSocketOpenSSL : public SSLClientSocket { // The given hostname will be compared with the name(s) in the server's // certificate during the SSL handshake. ssl_config specifies the SSL // settings. - SSLClientSocketOpenSSL(ClientSocketHandle* transport_socket, + SSLClientSocketOpenSSL(scoped_ptr<ClientSocketHandle> transport_socket, const HostPortPair& host_and_port, const SSLConfig& ssl_config, const SSLClientSocketContext& context); diff --git a/chromium/net/socket/ssl_client_socket_openssl_unittest.cc b/chromium/net/socket/ssl_client_socket_openssl_unittest.cc index 7a37cdc1187..24c06059be5 100644 --- a/chromium/net/socket/ssl_client_socket_openssl_unittest.cc +++ b/chromium/net/socket/ssl_client_socket_openssl_unittest.cc @@ -67,7 +67,7 @@ bool LoadPrivateKeyOpenSSL( const base::FilePath& filepath, OpenSSLClientKeyStore::ScopedEVP_PKEY* pkey) { std::string data; - if (!file_util::ReadFileToString(filepath, &data)) { + if (!base::ReadFileToString(filepath, &data)) { LOG(ERROR) << "Could not read private key file: " << filepath.value() << ": " << strerror(errno); return false; @@ -107,11 +107,13 @@ class SSLClientSocketOpenSSLClientAuthTest : public PlatformTest { } protected: - SSLClientSocket* CreateSSLClientSocket( - StreamSocket* transport_socket, + scoped_ptr<SSLClientSocket> CreateSSLClientSocket( + scoped_ptr<StreamSocket> transport_socket, const HostPortPair& host_and_port, const SSLConfig& ssl_config) { - return socket_factory_->CreateSSLClientSocket(transport_socket, + scoped_ptr<ClientSocketHandle> connection(new ClientSocketHandle); + connection->SetSocket(transport_socket.Pass()); + return socket_factory_->CreateSSLClientSocket(connection.Pass(), host_and_port, ssl_config, context_); @@ -164,9 +166,9 @@ class SSLClientSocketOpenSSLClientAuthTest : public PlatformTest { // itself was a success. bool CreateAndConnectSSLClientSocket(SSLConfig& ssl_config, int* result) { - sock_.reset(CreateSSLClientSocket(transport_.release(), - test_server_->host_port_pair(), - ssl_config)); + sock_ = CreateSSLClientSocket(transport_.Pass(), + test_server_->host_port_pair(), + ssl_config); if (sock_->IsConnected()) { LOG(ERROR) << "SSL Socket prematurely connected"; diff --git a/chromium/net/socket/ssl_client_socket_pool.cc b/chromium/net/socket/ssl_client_socket_pool.cc index fed268d4ee4..5d574b7edda 100644 --- a/chromium/net/socket/ssl_client_socket_pool.cc +++ b/chromium/net/socket/ssl_client_socket_pool.cc @@ -26,20 +26,18 @@ namespace net { SSLSocketParams::SSLSocketParams( - const scoped_refptr<TransportSocketParams>& transport_params, - const scoped_refptr<SOCKSSocketParams>& socks_params, + const scoped_refptr<TransportSocketParams>& direct_params, + const scoped_refptr<SOCKSSocketParams>& socks_proxy_params, const scoped_refptr<HttpProxySocketParams>& http_proxy_params, - ProxyServer::Scheme proxy, const HostPortPair& host_and_port, const SSLConfig& ssl_config, PrivacyMode privacy_mode, int load_flags, bool force_spdy_over_ssl, bool want_spdy_over_npn) - : transport_params_(transport_params), + : direct_params_(direct_params), + socks_proxy_params_(socks_proxy_params), http_proxy_params_(http_proxy_params), - socks_params_(socks_params), - proxy_(proxy), host_and_port_(host_and_port), ssl_config_(ssl_config), privacy_mode_(privacy_mode), @@ -47,39 +45,60 @@ SSLSocketParams::SSLSocketParams( force_spdy_over_ssl_(force_spdy_over_ssl), want_spdy_over_npn_(want_spdy_over_npn), ignore_limits_(false) { - switch (proxy_) { - case ProxyServer::SCHEME_DIRECT: - DCHECK(transport_params_.get() != NULL); - DCHECK(http_proxy_params_.get() == NULL); - DCHECK(socks_params_.get() == NULL); - ignore_limits_ = transport_params_->ignore_limits(); - break; - case ProxyServer::SCHEME_HTTP: - case ProxyServer::SCHEME_HTTPS: - DCHECK(transport_params_.get() == NULL); - DCHECK(http_proxy_params_.get() != NULL); - DCHECK(socks_params_.get() == NULL); - ignore_limits_ = http_proxy_params_->ignore_limits(); - break; - case ProxyServer::SCHEME_SOCKS4: - case ProxyServer::SCHEME_SOCKS5: - DCHECK(transport_params_.get() == NULL); - DCHECK(http_proxy_params_.get() == NULL); - DCHECK(socks_params_.get() != NULL); - ignore_limits_ = socks_params_->ignore_limits(); - break; - default: - LOG(DFATAL) << "unknown proxy type"; - break; + if (direct_params_) { + DCHECK(!socks_proxy_params_); + DCHECK(!http_proxy_params_); + ignore_limits_ = direct_params_->ignore_limits(); + } else if (socks_proxy_params_) { + DCHECK(!http_proxy_params_); + ignore_limits_ = socks_proxy_params_->ignore_limits(); + } else { + DCHECK(http_proxy_params_); + ignore_limits_ = http_proxy_params_->ignore_limits(); } } SSLSocketParams::~SSLSocketParams() {} +SSLSocketParams::ConnectionType SSLSocketParams::GetConnectionType() const { + if (direct_params_) { + DCHECK(!socks_proxy_params_); + DCHECK(!http_proxy_params_); + return DIRECT; + } + + if (socks_proxy_params_) { + DCHECK(!http_proxy_params_); + return SOCKS_PROXY; + } + + DCHECK(http_proxy_params_); + return HTTP_PROXY; +} + +const scoped_refptr<TransportSocketParams>& +SSLSocketParams::GetDirectConnectionParams() const { + DCHECK_EQ(GetConnectionType(), DIRECT); + return direct_params_; +} + +const scoped_refptr<SOCKSSocketParams>& +SSLSocketParams::GetSocksProxyConnectionParams() const { + DCHECK_EQ(GetConnectionType(), SOCKS_PROXY); + return socks_proxy_params_; +} + +const scoped_refptr<HttpProxySocketParams>& +SSLSocketParams::GetHttpProxyConnectionParams() const { + DCHECK_EQ(GetConnectionType(), HTTP_PROXY); + return http_proxy_params_; +} + // Timeout for the SSL handshake portion of the connect. static const int kSSLHandshakeTimeoutInSeconds = 30; SSLConnectJob::SSLConnectJob(const std::string& group_name, + RequestPriority priority, const scoped_refptr<SSLSocketParams>& params, const base::TimeDelta& timeout_duration, TransportClientSocketPool* transport_pool, @@ -92,6 +111,7 @@ SSLConnectJob::SSLConnectJob(const std::string& group_name, NetLog* net_log) : ConnectJob(group_name, timeout_duration, + priority, delegate, BoundNetLog::Make(net_log, NetLog::SOURCE_CONNECT_JOB)), params_(params), @@ -201,12 +221,14 @@ int SSLConnectJob::DoTransportConnect() { next_state_ = STATE_TRANSPORT_CONNECT_COMPLETE; transport_socket_handle_.reset(new ClientSocketHandle()); - scoped_refptr<TransportSocketParams> transport_params = - params_->transport_params(); - return transport_socket_handle_->Init( - group_name(), transport_params, - transport_params->destination().priority(), callback_, transport_pool_, - net_log()); + scoped_refptr<TransportSocketParams> direct_params = + params_->GetDirectConnectionParams(); + return transport_socket_handle_->Init(group_name(), + direct_params, + priority(), + callback_, + transport_pool_, + net_log()); } int SSLConnectJob::DoTransportConnectComplete(int result) { @@ -220,10 +242,14 @@ int SSLConnectJob::DoSOCKSConnect() { DCHECK(socks_pool_); next_state_ = STATE_SOCKS_CONNECT_COMPLETE; transport_socket_handle_.reset(new ClientSocketHandle()); - scoped_refptr<SOCKSSocketParams> socks_params = params_->socks_params(); - return transport_socket_handle_->Init( - group_name(), socks_params, socks_params->destination().priority(), - callback_, socks_pool_, net_log()); + scoped_refptr<SOCKSSocketParams> socks_proxy_params = + params_->GetSocksProxyConnectionParams(); + return transport_socket_handle_->Init(group_name(), + socks_proxy_params, + priority(), + callback_, + socks_pool_, + net_log()); } int SSLConnectJob::DoSOCKSConnectComplete(int result) { @@ -239,11 +265,13 @@ int SSLConnectJob::DoTunnelConnect() { transport_socket_handle_.reset(new ClientSocketHandle()); scoped_refptr<HttpProxySocketParams> http_proxy_params = - params_->http_proxy_params(); - return transport_socket_handle_->Init( - group_name(), http_proxy_params, - http_proxy_params->destination().priority(), callback_, http_proxy_pool_, - net_log()); + params_->GetHttpProxyConnectionParams(); + return transport_socket_handle_->Init(group_name(), + http_proxy_params, + priority(), + callback_, + http_proxy_pool_, + net_log()); } int SSLConnectJob::DoTunnelConnectComplete(int result) { @@ -287,11 +315,11 @@ int SSLConnectJob::DoSSLConnect() { connect_timing_.ssl_start = base::TimeTicks::Now(); - ssl_socket_.reset(client_socket_factory_->CreateSSLClientSocket( - transport_socket_handle_.release(), + ssl_socket_ = client_socket_factory_->CreateSSLClientSocket( + transport_socket_handle_.Pass(), params_->host_and_port(), params_->ssl_config(), - context_)); + context_); return ssl_socket_->Connect(callback_); } @@ -410,7 +438,7 @@ int SSLConnectJob::DoSSLConnectComplete(int result) { } if (result == OK || IsCertificateError(result)) { - set_socket(ssl_socket_.release()); + SetSocket(ssl_socket_.PassAs<StreamSocket>()); } else if (result == ERR_SSL_CLIENT_AUTH_CERT_NEEDED) { error_response_info_.cert_request_info = new SSLCertRequestInfo; ssl_socket_->GetSSLCertRequestInfo( @@ -420,23 +448,22 @@ int SSLConnectJob::DoSSLConnectComplete(int result) { return result; } -int SSLConnectJob::ConnectInternal() { - switch (params_->proxy()) { - case ProxyServer::SCHEME_DIRECT: - next_state_ = STATE_TRANSPORT_CONNECT; - break; - case ProxyServer::SCHEME_HTTP: - case ProxyServer::SCHEME_HTTPS: - next_state_ = STATE_TUNNEL_CONNECT; - break; - case ProxyServer::SCHEME_SOCKS4: - case ProxyServer::SCHEME_SOCKS5: - next_state_ = STATE_SOCKS_CONNECT; - break; - default: - NOTREACHED() << "unknown proxy type"; - break; +SSLConnectJob::State SSLConnectJob::GetInitialState( + SSLSocketParams::ConnectionType connection_type) { + switch (connection_type) { + case SSLSocketParams::DIRECT: + return STATE_TRANSPORT_CONNECT; + case SSLSocketParams::HTTP_PROXY: + return STATE_TUNNEL_CONNECT; + case SSLSocketParams::SOCKS_PROXY: + return STATE_SOCKS_CONNECT; } + NOTREACHED(); + return STATE_NONE; +} + +int SSLConnectJob::ConnectInternal() { + next_state_ = GetInitialState(params_->GetConnectionType()); return DoLoop(OK); } @@ -491,7 +518,7 @@ SSLClientSocketPool::SSLClientSocketPool( : transport_pool_(transport_pool), socks_pool_(socks_pool), http_proxy_pool_(http_proxy_pool), - base_(max_sockets, max_sockets_per_group, histograms, + base_(this, max_sockets, max_sockets_per_group, histograms, ClientSocketPool::unused_idle_socket_timeout(), ClientSocketPool::used_idle_socket_timeout(), new SSLConnectJobFactory(transport_pool, @@ -509,32 +536,28 @@ SSLClientSocketPool::SSLClientSocketPool( if (ssl_config_service_.get()) ssl_config_service_->AddObserver(this); if (transport_pool_) - transport_pool_->AddLayeredPool(this); + base_.AddLowerLayeredPool(transport_pool_); if (socks_pool_) - socks_pool_->AddLayeredPool(this); + base_.AddLowerLayeredPool(socks_pool_); if (http_proxy_pool_) - http_proxy_pool_->AddLayeredPool(this); + base_.AddLowerLayeredPool(http_proxy_pool_); } SSLClientSocketPool::~SSLClientSocketPool() { - if (http_proxy_pool_) - http_proxy_pool_->RemoveLayeredPool(this); - if (socks_pool_) - socks_pool_->RemoveLayeredPool(this); - if (transport_pool_) - transport_pool_->RemoveLayeredPool(this); if (ssl_config_service_.get()) ssl_config_service_->RemoveObserver(this); } -ConnectJob* SSLClientSocketPool::SSLConnectJobFactory::NewConnectJob( +scoped_ptr<ConnectJob> +SSLClientSocketPool::SSLConnectJobFactory::NewConnectJob( const std::string& group_name, const PoolBase::Request& request, ConnectJob::Delegate* delegate) const { - return new SSLConnectJob(group_name, request.params(), ConnectionTimeout(), - transport_pool_, socks_pool_, http_proxy_pool_, - client_socket_factory_, host_resolver_, - context_, delegate, net_log_); + return scoped_ptr<ConnectJob>( + new SSLConnectJob(group_name, request.priority(), request.params(), + ConnectionTimeout(), transport_pool_, socks_pool_, + http_proxy_pool_, client_socket_factory_, + host_resolver_, context_, delegate, net_log_)); } base::TimeDelta @@ -572,21 +595,15 @@ void SSLClientSocketPool::CancelRequest(const std::string& group_name, } void SSLClientSocketPool::ReleaseSocket(const std::string& group_name, - StreamSocket* socket, int id) { - base_.ReleaseSocket(group_name, socket, id); + scoped_ptr<StreamSocket> socket, + int id) { + base_.ReleaseSocket(group_name, socket.Pass(), id); } void SSLClientSocketPool::FlushWithError(int error) { base_.FlushWithError(error); } -bool SSLClientSocketPool::IsStalled() const { - return base_.IsStalled() || - (transport_pool_ && transport_pool_->IsStalled()) || - (socks_pool_ && socks_pool_->IsStalled()) || - (http_proxy_pool_ && http_proxy_pool_->IsStalled()); -} - void SSLClientSocketPool::CloseIdleSockets() { base_.CloseIdleSockets(); } @@ -605,14 +622,6 @@ LoadState SSLClientSocketPool::GetLoadState( return base_.GetLoadState(group_name, handle); } -void SSLClientSocketPool::AddLayeredPool(LayeredPool* layered_pool) { - base_.AddLayeredPool(layered_pool); -} - -void SSLClientSocketPool::RemoveLayeredPool(LayeredPool* layered_pool) { - base_.RemoveLayeredPool(layered_pool); -} - base::DictionaryValue* SSLClientSocketPool::GetInfoAsValue( const std::string& name, const std::string& type, @@ -648,14 +657,27 @@ ClientSocketPoolHistograms* SSLClientSocketPool::histograms() const { return base_.histograms(); } -void SSLClientSocketPool::OnSSLConfigChanged() { - FlushWithError(ERR_NETWORK_CHANGED); +bool SSLClientSocketPool::IsStalled() const { + return base_.IsStalled(); +} + +void SSLClientSocketPool::AddHigherLayeredPool(HigherLayeredPool* higher_pool) { + base_.AddHigherLayeredPool(higher_pool); +} + +void SSLClientSocketPool::RemoveHigherLayeredPool( + HigherLayeredPool* higher_pool) { + base_.RemoveHigherLayeredPool(higher_pool); } bool SSLClientSocketPool::CloseOneIdleConnection() { if (base_.CloseOneIdleSocket()) return true; - return base_.CloseOneIdleConnectionInLayeredPool(); + return base_.CloseOneIdleConnectionInHigherLayeredPool(); +} + +void SSLClientSocketPool::OnSSLConfigChanged() { + FlushWithError(ERR_NETWORK_CHANGED); } } // namespace net diff --git a/chromium/net/socket/ssl_client_socket_pool.h b/chromium/net/socket/ssl_client_socket_pool.h index bc54bc92f9a..ec62eb01f46 100644 --- a/chromium/net/socket/ssl_client_socket_pool.h +++ b/chromium/net/socket/ssl_client_socket_pool.h @@ -13,7 +13,6 @@ #include "net/base/privacy_mode.h" #include "net/dns/host_resolver.h" #include "net/http/http_response_info.h" -#include "net/proxy/proxy_server.h" #include "net/socket/client_socket_pool.h" #include "net/socket/client_socket_pool_base.h" #include "net/socket/client_socket_pool_histograms.h" @@ -35,32 +34,39 @@ class TransportClientSocketPool; class TransportSecurityState; class TransportSocketParams; -// SSLSocketParams only needs the socket params for the transport socket -// that will be used (denoted by |proxy|). class NET_EXPORT_PRIVATE SSLSocketParams : public base::RefCounted<SSLSocketParams> { public: - SSLSocketParams(const scoped_refptr<TransportSocketParams>& transport_params, - const scoped_refptr<SOCKSSocketParams>& socks_params, - const scoped_refptr<HttpProxySocketParams>& http_proxy_params, - ProxyServer::Scheme proxy, - const HostPortPair& host_and_port, - const SSLConfig& ssl_config, - PrivacyMode privacy_mode, - int load_flags, - bool force_spdy_over_ssl, - bool want_spdy_over_npn); - - const scoped_refptr<TransportSocketParams>& transport_params() { - return transport_params_; - } - const scoped_refptr<HttpProxySocketParams>& http_proxy_params() { - return http_proxy_params_; - } - const scoped_refptr<SOCKSSocketParams>& socks_params() { - return socks_params_; - } - ProxyServer::Scheme proxy() const { return proxy_; } + enum ConnectionType { DIRECT, SOCKS_PROXY, HTTP_PROXY }; + + // Exactly one of |direct_params|, |socks_proxy_params|, and + // |http_proxy_params| must be non-NULL. + SSLSocketParams( + const scoped_refptr<TransportSocketParams>& direct_params, + const scoped_refptr<SOCKSSocketParams>& socks_proxy_params, + const scoped_refptr<HttpProxySocketParams>& http_proxy_params, + const HostPortPair& host_and_port, + const SSLConfig& ssl_config, + PrivacyMode privacy_mode, + int load_flags, + bool force_spdy_over_ssl, + bool want_spdy_over_npn); + + // Returns the type of the underlying connection. + ConnectionType GetConnectionType() const; + + // Must be called only when GetConnectionType() returns DIRECT. + const scoped_refptr<TransportSocketParams>& + GetDirectConnectionParams() const; + + // Must be called only when GetConnectionType() returns SOCKS_PROXY. + const scoped_refptr<SOCKSSocketParams>& + GetSocksProxyConnectionParams() const; + + // Must be called only when GetConnectionType() returns HTTP_PROXY. + const scoped_refptr<HttpProxySocketParams>& + GetHttpProxyConnectionParams() const; + const HostPortPair& host_and_port() const { return host_and_port_; } const SSLConfig& ssl_config() const { return ssl_config_; } PrivacyMode privacy_mode() const { return privacy_mode_; } @@ -73,10 +79,9 @@ class NET_EXPORT_PRIVATE SSLSocketParams friend class base::RefCounted<SSLSocketParams>; ~SSLSocketParams(); - const scoped_refptr<TransportSocketParams> transport_params_; + const scoped_refptr<TransportSocketParams> direct_params_; + const scoped_refptr<SOCKSSocketParams> socks_proxy_params_; const scoped_refptr<HttpProxySocketParams> http_proxy_params_; - const scoped_refptr<SOCKSSocketParams> socks_params_; - const ProxyServer::Scheme proxy_; const HostPortPair host_and_port_; const SSLConfig ssl_config_; const PrivacyMode privacy_mode_; @@ -94,6 +99,7 @@ class SSLConnectJob : public ConnectJob { public: SSLConnectJob( const std::string& group_name, + RequestPriority priority, const scoped_refptr<SSLSocketParams>& params, const base::TimeDelta& timeout_duration, TransportClientSocketPool* transport_pool, @@ -138,6 +144,10 @@ class SSLConnectJob : public ConnectJob { int DoSSLConnect(); int DoSSLConnectComplete(int result); + // Returns the initial state for the state machine based on the + // |connection_type|. + static State GetInitialState(SSLSocketParams::ConnectionType connection_type); + // Starts the SSL connection process. Returns OK on success and // ERR_IO_PENDING if it cannot immediately service the request. // Otherwise, it returns a net error code. @@ -164,9 +174,11 @@ class SSLConnectJob : public ConnectJob { class NET_EXPORT_PRIVATE SSLClientSocketPool : public ClientSocketPool, - public LayeredPool, + public HigherLayeredPool, public SSLConfigService::Observer { public: + typedef SSLSocketParams SocketParams; + // Only the pools that will be used are required. i.e. if you never // try to create an SSL over SOCKS socket, |socks_pool| may be NULL. SSLClientSocketPool( @@ -204,13 +216,11 @@ class NET_EXPORT_PRIVATE SSLClientSocketPool ClientSocketHandle* handle) OVERRIDE; virtual void ReleaseSocket(const std::string& group_name, - StreamSocket* socket, + scoped_ptr<StreamSocket> socket, int id) OVERRIDE; virtual void FlushWithError(int error) OVERRIDE; - virtual bool IsStalled() const OVERRIDE; - virtual void CloseIdleSockets() OVERRIDE; virtual int IdleSocketCount() const OVERRIDE; @@ -222,10 +232,6 @@ class NET_EXPORT_PRIVATE SSLClientSocketPool const std::string& group_name, const ClientSocketHandle* handle) const OVERRIDE; - virtual void AddLayeredPool(LayeredPool* layered_pool) OVERRIDE; - - virtual void RemoveLayeredPool(LayeredPool* layered_pool) OVERRIDE; - virtual base::DictionaryValue* GetInfoAsValue( const std::string& name, const std::string& type, @@ -235,7 +241,14 @@ class NET_EXPORT_PRIVATE SSLClientSocketPool virtual ClientSocketPoolHistograms* histograms() const OVERRIDE; - // LayeredPool implementation. + // LowerLayeredPool implementation. + virtual bool IsStalled() const OVERRIDE; + + virtual void AddHigherLayeredPool(HigherLayeredPool* higher_pool) OVERRIDE; + + virtual void RemoveHigherLayeredPool(HigherLayeredPool* higher_pool) OVERRIDE; + + // HigherLayeredPool implementation. virtual bool CloseOneIdleConnection() OVERRIDE; private: @@ -261,7 +274,7 @@ class NET_EXPORT_PRIVATE SSLClientSocketPool virtual ~SSLConnectJobFactory() {} // ClientSocketPoolBase::ConnectJobFactory methods. - virtual ConnectJob* NewConnectJob( + virtual scoped_ptr<ConnectJob> NewConnectJob( const std::string& group_name, const PoolBase::Request& request, ConnectJob::Delegate* delegate) const OVERRIDE; @@ -290,8 +303,6 @@ class NET_EXPORT_PRIVATE SSLClientSocketPool DISALLOW_COPY_AND_ASSIGN(SSLClientSocketPool); }; -REGISTER_SOCKET_PARAMS_FOR_POOL(SSLClientSocketPool, SSLSocketParams); - } // namespace net #endif // NET_SOCKET_SSL_CLIENT_SOCKET_POOL_H_ diff --git a/chromium/net/socket/ssl_client_socket_pool_unittest.cc b/chromium/net/socket/ssl_client_socket_pool_unittest.cc index 280f6e7af1a..8aecb98dd85 100644 --- a/chromium/net/socket/ssl_client_socket_pool_unittest.cc +++ b/chromium/net/socket/ssl_client_socket_pool_unittest.cc @@ -85,7 +85,6 @@ class SSLClientSocketPoolTest session_(CreateNetworkSession()), direct_transport_socket_params_( new TransportSocketParams(HostPortPair("host", 443), - MEDIUM, false, false, OnHostResolutionCallback())), @@ -96,15 +95,13 @@ class SSLClientSocketPoolTest &socket_factory_), proxy_transport_socket_params_( new TransportSocketParams(HostPortPair("proxy", 443), - MEDIUM, false, false, OnHostResolutionCallback())), socks_socket_params_( new SOCKSSocketParams(proxy_transport_socket_params_, true, - HostPortPair("sockshost", 443), - MEDIUM)), + HostPortPair("sockshost", 443))), socks_histograms_("MockSOCKS"), socks_socket_pool_(kMaxSockets, kMaxSocketsPerGroup, @@ -159,7 +156,6 @@ class SSLClientSocketPoolTest : NULL, proxy == ProxyServer::SCHEME_SOCKS5 ? socks_socket_params_ : NULL, proxy == ProxyServer::SCHEME_HTTP ? http_proxy_socket_params_ : NULL, - proxy, HostPortPair("host", 443), ssl_config_, kPrivacyModeDisabled, @@ -294,6 +290,30 @@ TEST_P(SSLClientSocketPoolTest, BasicDirect) { TestLoadTimingInfo(handle); } +// Make sure that SSLConnectJob passes on its priority to its +// socket request on Init (for the DIRECT case). +TEST_P(SSLClientSocketPoolTest, SetSocketRequestPriorityOnInitDirect) { + CreatePool(true /* tcp pool */, false, false); + scoped_refptr<SSLSocketParams> params = + SSLParams(ProxyServer::SCHEME_DIRECT, false); + + for (int i = MINIMUM_PRIORITY; i < NUM_PRIORITIES; ++i) { + RequestPriority priority = static_cast<RequestPriority>(i); + StaticSocketDataProvider data; + data.set_connect_data(MockConnect(SYNCHRONOUS, OK)); + socket_factory_.AddSocketDataProvider(&data); + SSLSocketDataProvider ssl(SYNCHRONOUS, OK); + socket_factory_.AddSSLSocketDataProvider(&ssl); + + ClientSocketHandle handle; + TestCompletionCallback callback; + EXPECT_EQ(OK, handle.Init("a", params, priority, callback.callback(), + pool_.get(), BoundNetLog())); + EXPECT_EQ(priority, transport_socket_pool_.last_request_priority()); + handle.socket()->Disconnect(); + } +} + TEST_P(SSLClientSocketPoolTest, BasicDirectAsync) { StaticSocketDataProvider data; socket_factory_.AddSocketDataProvider(&data); @@ -547,6 +567,26 @@ TEST_P(SSLClientSocketPoolTest, SOCKSBasic) { TestLoadTimingInfo(handle); } +// Make sure that SSLConnectJob passes on its priority to its +// transport socket on Init (for the SOCKS_PROXY case). +TEST_P(SSLClientSocketPoolTest, SetTransportPriorityOnInitSOCKS) { + StaticSocketDataProvider data; + data.set_connect_data(MockConnect(SYNCHRONOUS, OK)); + socket_factory_.AddSocketDataProvider(&data); + SSLSocketDataProvider ssl(SYNCHRONOUS, OK); + socket_factory_.AddSSLSocketDataProvider(&ssl); + + CreatePool(false, true /* http proxy pool */, true /* socks pool */); + scoped_refptr<SSLSocketParams> params = + SSLParams(ProxyServer::SCHEME_SOCKS5, false); + + ClientSocketHandle handle; + TestCompletionCallback callback; + EXPECT_EQ(OK, handle.Init("a", params, HIGHEST, callback.callback(), + pool_.get(), BoundNetLog())); + EXPECT_EQ(HIGHEST, transport_socket_pool_.last_request_priority()); +} + TEST_P(SSLClientSocketPoolTest, SOCKSBasicAsync) { StaticSocketDataProvider data; socket_factory_.AddSocketDataProvider(&data); @@ -648,6 +688,38 @@ TEST_P(SSLClientSocketPoolTest, HttpProxyBasic) { TestLoadTimingInfoNoDns(handle); } +// Make sure that SSLConnectJob passes on its priority to its +// transport socket on Init (for the HTTP_PROXY case). +TEST_P(SSLClientSocketPoolTest, SetTransportPriorityOnInitHTTP) { + MockWrite writes[] = { + MockWrite(SYNCHRONOUS, + "CONNECT host:80 HTTP/1.1\r\n" + "Host: host\r\n" + "Proxy-Connection: keep-alive\r\n" + "Proxy-Authorization: Basic Zm9vOmJhcg==\r\n\r\n"), + }; + MockRead reads[] = { + MockRead(SYNCHRONOUS, "HTTP/1.1 200 Connection Established\r\n\r\n"), + }; + StaticSocketDataProvider data(reads, arraysize(reads), writes, + arraysize(writes)); + data.set_connect_data(MockConnect(SYNCHRONOUS, OK)); + socket_factory_.AddSocketDataProvider(&data); + AddAuthToCache(); + SSLSocketDataProvider ssl(SYNCHRONOUS, OK); + socket_factory_.AddSSLSocketDataProvider(&ssl); + + CreatePool(false, true /* http proxy pool */, true /* socks pool */); + scoped_refptr<SSLSocketParams> params = + SSLParams(ProxyServer::SCHEME_HTTP, false); + + ClientSocketHandle handle; + TestCompletionCallback callback; + EXPECT_EQ(OK, handle.Init("a", params, HIGHEST, callback.callback(), + pool_.get(), BoundNetLog())); + EXPECT_EQ(HIGHEST, transport_socket_pool_.last_request_priority()); +} + TEST_P(SSLClientSocketPoolTest, HttpProxyBasicAsync) { MockWrite writes[] = { MockWrite("CONNECT host:80 HTTP/1.1\r\n" @@ -746,8 +818,12 @@ TEST_P(SSLClientSocketPoolTest, IPPooling) { // This test requires that the HostResolver cache be populated. Normal // code would have done this already, but we do it manually. HostResolver::RequestInfo info(HostPortPair(test_hosts[i].name, kTestPort)); - host_resolver_.Resolve(info, &test_hosts[i].addresses, CompletionCallback(), - NULL, BoundNetLog()); + host_resolver_.Resolve(info, + DEFAULT_PRIORITY, + &test_hosts[i].addresses, + CompletionCallback(), + NULL, + BoundNetLog()); // Setup a SpdySessionKey test_hosts[i].key = SpdySessionKey( @@ -802,8 +878,12 @@ void SSLClientSocketPoolTest::TestIPPoolingDisabled( // This test requires that the HostResolver cache be populated. Normal // code would have done this already, but we do it manually. HostResolver::RequestInfo info(HostPortPair(test_hosts[i].name, kTestPort)); - rv = host_resolver_.Resolve(info, &test_hosts[i].addresses, - callback.callback(), NULL, BoundNetLog()); + rv = host_resolver_.Resolve(info, + DEFAULT_PRIORITY, + &test_hosts[i].addresses, + callback.callback(), + NULL, + BoundNetLog()); EXPECT_EQ(OK, callback.GetResult(rv)); // Setup a SpdySessionKey diff --git a/chromium/net/socket/ssl_client_socket_unittest.cc b/chromium/net/socket/ssl_client_socket_unittest.cc index f0e7120a135..f791928580f 100644 --- a/chromium/net/socket/ssl_client_socket_unittest.cc +++ b/chromium/net/socket/ssl_client_socket_unittest.cc @@ -30,9 +30,11 @@ //----------------------------------------------------------------------------- +namespace net { + namespace { -const net::SSLConfig kDefaultSSLConfig; +const SSLConfig kDefaultSSLConfig; // WrappedStreamSocket is a base class that wraps an existing StreamSocket, // forwarding the Socket and StreamSocket interfaces to the underlying @@ -40,33 +42,30 @@ const net::SSLConfig kDefaultSSLConfig; // This is to provide a common base class for subclasses to override specific // StreamSocket methods for testing, while still communicating with a 'real' // StreamSocket. -class WrappedStreamSocket : public net::StreamSocket { +class WrappedStreamSocket : public StreamSocket { public: - explicit WrappedStreamSocket(scoped_ptr<net::StreamSocket> transport) - : transport_(transport.Pass()) { - } + explicit WrappedStreamSocket(scoped_ptr<StreamSocket> transport) + : transport_(transport.Pass()) {} virtual ~WrappedStreamSocket() {} // StreamSocket implementation: - virtual int Connect(const net::CompletionCallback& callback) OVERRIDE { + virtual int Connect(const CompletionCallback& callback) OVERRIDE { return transport_->Connect(callback); } - virtual void Disconnect() OVERRIDE { - transport_->Disconnect(); - } + virtual void Disconnect() OVERRIDE { transport_->Disconnect(); } virtual bool IsConnected() const OVERRIDE { return transport_->IsConnected(); } virtual bool IsConnectedAndIdle() const OVERRIDE { return transport_->IsConnectedAndIdle(); } - virtual int GetPeerAddress(net::IPEndPoint* address) const OVERRIDE { + virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE { return transport_->GetPeerAddress(address); } - virtual int GetLocalAddress(net::IPEndPoint* address) const OVERRIDE { + virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE { return transport_->GetLocalAddress(address); } - virtual const net::BoundNetLog& NetLog() const OVERRIDE { + virtual const BoundNetLog& NetLog() const OVERRIDE { return transport_->NetLog(); } virtual void SetSubresourceSpeculation() OVERRIDE { @@ -84,20 +83,22 @@ class WrappedStreamSocket : public net::StreamSocket { virtual bool WasNpnNegotiated() const OVERRIDE { return transport_->WasNpnNegotiated(); } - virtual net::NextProto GetNegotiatedProtocol() const OVERRIDE { + virtual NextProto GetNegotiatedProtocol() const OVERRIDE { return transport_->GetNegotiatedProtocol(); } - virtual bool GetSSLInfo(net::SSLInfo* ssl_info) OVERRIDE { + virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE { return transport_->GetSSLInfo(ssl_info); } // Socket implementation: - virtual int Read(net::IOBuffer* buf, int buf_len, - const net::CompletionCallback& callback) OVERRIDE { + virtual int Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) OVERRIDE { return transport_->Read(buf, buf_len, callback); } - virtual int Write(net::IOBuffer* buf, int buf_len, - const net::CompletionCallback& callback) OVERRIDE { + virtual int Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) OVERRIDE { return transport_->Write(buf, buf_len, callback); } virtual bool SetReceiveBufferSize(int32 size) OVERRIDE { @@ -108,7 +109,7 @@ class WrappedStreamSocket : public net::StreamSocket { } protected: - scoped_ptr<net::StreamSocket> transport_; + scoped_ptr<StreamSocket> transport_; }; // ReadBufferingStreamSocket is a wrapper for an existing StreamSocket that @@ -119,12 +120,13 @@ class WrappedStreamSocket : public net::StreamSocket { // them from the TestServer. class ReadBufferingStreamSocket : public WrappedStreamSocket { public: - explicit ReadBufferingStreamSocket(scoped_ptr<net::StreamSocket> transport); + explicit ReadBufferingStreamSocket(scoped_ptr<StreamSocket> transport); virtual ~ReadBufferingStreamSocket() {} // Socket implementation: - virtual int Read(net::IOBuffer* buf, int buf_len, - const net::CompletionCallback& callback) OVERRIDE; + virtual int Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) OVERRIDE; // Sets the internal buffer to |size|. This must not be greater than // the largest value supplied to Read() - that is, it does not handle @@ -148,19 +150,18 @@ class ReadBufferingStreamSocket : public WrappedStreamSocket { void OnReadCompleted(int result); State state_; - scoped_refptr<net::GrowableIOBuffer> read_buffer_; + scoped_refptr<GrowableIOBuffer> read_buffer_; int buffer_size_; - scoped_refptr<net::IOBuffer> user_read_buf_; - net::CompletionCallback user_read_callback_; + scoped_refptr<IOBuffer> user_read_buf_; + CompletionCallback user_read_callback_; }; ReadBufferingStreamSocket::ReadBufferingStreamSocket( - scoped_ptr<net::StreamSocket> transport) + scoped_ptr<StreamSocket> transport) : WrappedStreamSocket(transport.Pass()), - read_buffer_(new net::GrowableIOBuffer()), - buffer_size_(0) { -} + read_buffer_(new GrowableIOBuffer()), + buffer_size_(0) {} void ReadBufferingStreamSocket::SetBufferSize(int size) { DCHECK(!user_read_buf_.get()); @@ -168,19 +169,19 @@ void ReadBufferingStreamSocket::SetBufferSize(int size) { read_buffer_->SetCapacity(size); } -int ReadBufferingStreamSocket::Read(net::IOBuffer* buf, +int ReadBufferingStreamSocket::Read(IOBuffer* buf, int buf_len, - const net::CompletionCallback& callback) { + const CompletionCallback& callback) { if (buffer_size_ == 0) return transport_->Read(buf, buf_len, callback); if (buf_len < buffer_size_) - return net::ERR_UNEXPECTED; + return ERR_UNEXPECTED; state_ = STATE_READ; user_read_buf_ = buf; - int result = DoLoop(net::OK); - if (result == net::ERR_IO_PENDING) + int result = DoLoop(OK); + if (result == ERR_IO_PENDING) user_read_callback_ = callback; else user_read_buf_ = NULL; @@ -202,10 +203,10 @@ int ReadBufferingStreamSocket::DoLoop(int result) { case STATE_NONE: default: NOTREACHED() << "Unexpected state: " << current_state; - rv = net::ERR_UNEXPECTED; + rv = ERR_UNEXPECTED; break; } - } while (rv != net::ERR_IO_PENDING && state_ != STATE_NONE); + } while (rv != ERR_IO_PENDING && state_ != STATE_NONE); return rv; } @@ -227,10 +228,11 @@ int ReadBufferingStreamSocket::DoReadComplete(int result) { read_buffer_->set_offset(read_buffer_->offset() + result); if (read_buffer_->RemainingCapacity() > 0) { state_ = STATE_READ; - return net::OK; + return OK; } - memcpy(user_read_buf_->data(), read_buffer_->StartOfBuffer(), + memcpy(user_read_buf_->data(), + read_buffer_->StartOfBuffer(), read_buffer_->capacity()); read_buffer_->set_offset(0); return read_buffer_->capacity(); @@ -238,7 +240,7 @@ int ReadBufferingStreamSocket::DoReadComplete(int result) { void ReadBufferingStreamSocket::OnReadCompleted(int result) { result = DoLoop(result); - if (result == net::ERR_IO_PENDING) + if (result == ERR_IO_PENDING) return; user_read_buf_ = NULL; @@ -252,16 +254,18 @@ class SynchronousErrorStreamSocket : public WrappedStreamSocket { virtual ~SynchronousErrorStreamSocket() {} // Socket implementation: - virtual int Read(net::IOBuffer* buf, int buf_len, - const net::CompletionCallback& callback) OVERRIDE; - virtual int Write(net::IOBuffer* buf, int buf_len, - const net::CompletionCallback& callback) OVERRIDE; + virtual int Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) OVERRIDE; + virtual int Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) OVERRIDE; // Sets the next Read() call and all future calls to return |error|. // If there is already a pending asynchronous read, the configured error // will not be returned until that asynchronous read has completed and Read() // is called again. - void SetNextReadError(net::Error error) { + void SetNextReadError(Error error) { DCHECK_GE(0, error); have_read_error_ = true; pending_read_error_ = error; @@ -271,7 +275,7 @@ class SynchronousErrorStreamSocket : public WrappedStreamSocket { // If there is already a pending asynchronous write, the configured error // will not be returned until that asynchronous write has completed and // Write() is called again. - void SetNextWriteError(net::Error error) { + void SetNextWriteError(Error error) { DCHECK_GE(0, error); have_write_error_ = true; pending_write_error_ = error; @@ -291,24 +295,21 @@ SynchronousErrorStreamSocket::SynchronousErrorStreamSocket( scoped_ptr<StreamSocket> transport) : WrappedStreamSocket(transport.Pass()), have_read_error_(false), - pending_read_error_(net::OK), + pending_read_error_(OK), have_write_error_(false), - pending_write_error_(net::OK) { -} + pending_write_error_(OK) {} -int SynchronousErrorStreamSocket::Read( - net::IOBuffer* buf, - int buf_len, - const net::CompletionCallback& callback) { +int SynchronousErrorStreamSocket::Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) { if (have_read_error_) return pending_read_error_; return transport_->Read(buf, buf_len, callback); } -int SynchronousErrorStreamSocket::Write( - net::IOBuffer* buf, - int buf_len, - const net::CompletionCallback& callback) { +int SynchronousErrorStreamSocket::Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) { if (have_write_error_) return pending_write_error_; return transport_->Write(buf, buf_len, callback); @@ -324,12 +325,14 @@ class FakeBlockingStreamSocket : public WrappedStreamSocket { virtual ~FakeBlockingStreamSocket() {} // Socket implementation: - virtual int Read(net::IOBuffer* buf, int buf_len, - const net::CompletionCallback& callback) OVERRIDE { + virtual int Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) OVERRIDE { return read_state_.RunWrappedFunction(buf, buf_len, callback); } - virtual int Write(net::IOBuffer* buf, int buf_len, - const net::CompletionCallback& callback) OVERRIDE { + virtual int Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) OVERRIDE { return write_state_.RunWrappedFunction(buf, buf_len, callback); } @@ -350,9 +353,8 @@ class FakeBlockingStreamSocket : public WrappedStreamSocket { class BlockingState { public: // Wrapper for the underlying Socket function to call (ie: Read/Write). - typedef base::Callback< - int(net::IOBuffer*, int, - const net::CompletionCallback&)> WrappedSocketFunction; + typedef base::Callback<int(IOBuffer*, int, const CompletionCallback&)> + WrappedSocketFunction; explicit BlockingState(const WrappedSocketFunction& function); ~BlockingState() {} @@ -371,8 +373,9 @@ class FakeBlockingStreamSocket : public WrappedStreamSocket { // Performs the wrapped socket function on the underlying transport. If // configured to block via SetShouldBlock(), then |user_callback| will not // be invoked until Unblock() has been called. - int RunWrappedFunction(net::IOBuffer* buf, int len, - const net::CompletionCallback& user_callback); + int RunWrappedFunction(IOBuffer* buf, + int len, + const CompletionCallback& user_callback); private: // Handles completion from the underlying wrapped socket function. @@ -382,7 +385,7 @@ class FakeBlockingStreamSocket : public WrappedStreamSocket { bool should_block_; bool have_result_; int pending_result_; - net::CompletionCallback user_callback_; + CompletionCallback user_callback_; }; BlockingState read_state_; @@ -397,16 +400,14 @@ FakeBlockingStreamSocket::FakeBlockingStreamSocket( read_state_(base::Bind(&Socket::Read, base::Unretained(transport_.get()))), write_state_(base::Bind(&Socket::Write, - base::Unretained(transport_.get()))) { -} + base::Unretained(transport_.get()))) {} FakeBlockingStreamSocket::BlockingState::BlockingState( const WrappedSocketFunction& function) : wrapped_function_(function), should_block_(false), have_result_(false), - pending_result_(net::OK) { -} + pending_result_(OK) {} void FakeBlockingStreamSocket::BlockingState::SetShouldBlock() { DCHECK(!should_block_); @@ -429,24 +430,24 @@ void FakeBlockingStreamSocket::BlockingState::Unblock() { } int FakeBlockingStreamSocket::BlockingState::RunWrappedFunction( - net::IOBuffer* buf, + IOBuffer* buf, int len, - const net::CompletionCallback& callback) { + const CompletionCallback& callback) { // The callback to be called by the underlying transport. Either forward // directly to the user's callback if not set to block, or intercept it with // OnCompleted so that the user's callback is not invoked until Unblock() is // called. - net::CompletionCallback transport_callback = + CompletionCallback transport_callback = !should_block_ ? callback : base::Bind(&BlockingState::OnCompleted, base::Unretained(this)); int rv = wrapped_function_.Run(buf, len, transport_callback); if (should_block_) { user_callback_ = callback; // May have completed synchronously. - have_result_ = (rv != net::ERR_IO_PENDING); + have_result_ = (rv != ERR_IO_PENDING); pending_result_ = rv; - return net::ERR_IO_PENDING; + return ERR_IO_PENDING; } return rv; @@ -466,64 +467,61 @@ void FakeBlockingStreamSocket::BlockingState::OnCompleted(int result) { base::ResetAndReturn(&user_callback_).Run(result); } -// CompletionCallback that will delete the associated net::StreamSocket when +// CompletionCallback that will delete the associated StreamSocket when // the callback is invoked. -class DeleteSocketCallback : public net::TestCompletionCallbackBase { +class DeleteSocketCallback : public TestCompletionCallbackBase { public: - explicit DeleteSocketCallback(net::StreamSocket* socket) + explicit DeleteSocketCallback(StreamSocket* socket) : socket_(socket), callback_(base::Bind(&DeleteSocketCallback::OnComplete, - base::Unretained(this))) { - } + base::Unretained(this))) {} virtual ~DeleteSocketCallback() {} - const net::CompletionCallback& callback() const { return callback_; } + const CompletionCallback& callback() const { return callback_; } private: void OnComplete(int result) { - if (socket_) { - delete socket_; - socket_ = NULL; - } else { - ADD_FAILURE() << "Deleting socket twice"; - } - SetResult(result); + if (socket_) { + delete socket_; + socket_ = NULL; + } else { + ADD_FAILURE() << "Deleting socket twice"; + } + SetResult(result); } - net::StreamSocket* socket_; - net::CompletionCallback callback_; + StreamSocket* socket_; + CompletionCallback callback_; DISALLOW_COPY_AND_ASSIGN(DeleteSocketCallback); }; -} // namespace - class SSLClientSocketTest : public PlatformTest { public: SSLClientSocketTest() - : socket_factory_(net::ClientSocketFactory::GetDefaultFactory()), - cert_verifier_(new net::MockCertVerifier), - transport_security_state_(new net::TransportSecurityState) { - cert_verifier_->set_default_result(net::OK); + : socket_factory_(ClientSocketFactory::GetDefaultFactory()), + cert_verifier_(new MockCertVerifier), + transport_security_state_(new TransportSecurityState) { + cert_verifier_->set_default_result(OK); context_.cert_verifier = cert_verifier_.get(); context_.transport_security_state = transport_security_state_.get(); } protected: - net::SSLClientSocket* CreateSSLClientSocket( - net::StreamSocket* transport_socket, - const net::HostPortPair& host_and_port, - const net::SSLConfig& ssl_config) { - return socket_factory_->CreateSSLClientSocket(transport_socket, - host_and_port, - ssl_config, - context_); + scoped_ptr<SSLClientSocket> CreateSSLClientSocket( + scoped_ptr<StreamSocket> transport_socket, + const HostPortPair& host_and_port, + const SSLConfig& ssl_config) { + scoped_ptr<ClientSocketHandle> connection(new ClientSocketHandle); + connection->SetSocket(transport_socket.Pass()); + return socket_factory_->CreateSSLClientSocket( + connection.Pass(), host_and_port, ssl_config, context_); } - net::ClientSocketFactory* socket_factory_; - scoped_ptr<net::MockCertVerifier> cert_verifier_; - scoped_ptr<net::TransportSecurityState> transport_security_state_; - net::SSLClientSocketContext context_; + ClientSocketFactory* socket_factory_; + scoped_ptr<MockCertVerifier> cert_verifier_; + scoped_ptr<TransportSecurityState> transport_security_state_; + SSLClientSocketContext context_; }; //----------------------------------------------------------------------------- @@ -536,45 +534,45 @@ class SSLClientSocketTest : public PlatformTest { // timeout. This means that an SSL connect end event may appear as a socket // write. static bool LogContainsSSLConnectEndEvent( - const net::CapturingNetLog::CapturedEntryList& log, int i) { - return net::LogContainsEndEvent(log, i, net::NetLog::TYPE_SSL_CONNECT) || - net::LogContainsEvent(log, i, net::NetLog::TYPE_SOCKET_BYTES_SENT, - net::NetLog::PHASE_NONE); -}; + const CapturingNetLog::CapturedEntryList& log, + int i) { + return LogContainsEndEvent(log, i, NetLog::TYPE_SSL_CONNECT) || + LogContainsEvent( + log, i, NetLog::TYPE_SOCKET_BYTES_SENT, NetLog::PHASE_NONE); +} +; TEST_F(SSLClientSocketTest, Connect) { - net::SpawnedTestServer test_server(net::SpawnedTestServer::TYPE_HTTPS, - net::SpawnedTestServer::kLocalhost, - base::FilePath()); + SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS, + SpawnedTestServer::kLocalhost, + base::FilePath()); ASSERT_TRUE(test_server.Start()); - net::AddressList addr; + AddressList addr; ASSERT_TRUE(test_server.GetAddressList(&addr)); - net::TestCompletionCallback callback; - net::CapturingNetLog log; - net::StreamSocket* transport = new net::TCPClientSocket( - addr, &log, net::NetLog::Source()); + TestCompletionCallback callback; + CapturingNetLog log; + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr, &log, NetLog::Source())); int rv = transport->Connect(callback.callback()); - if (rv == net::ERR_IO_PENDING) + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); - EXPECT_EQ(net::OK, rv); + EXPECT_EQ(OK, rv); - scoped_ptr<net::SSLClientSocket> sock( - CreateSSLClientSocket(transport, test_server.host_port_pair(), - kDefaultSSLConfig)); + scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( + transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); EXPECT_FALSE(sock->IsConnected()); rv = sock->Connect(callback.callback()); - net::CapturingNetLog::CapturedEntryList entries; + CapturingNetLog::CapturedEntryList entries; log.GetEntries(&entries); - EXPECT_TRUE(net::LogContainsBeginEvent( - entries, 5, net::NetLog::TYPE_SSL_CONNECT)); - if (rv == net::ERR_IO_PENDING) + EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT)); + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); - EXPECT_EQ(net::OK, rv); + EXPECT_EQ(OK, rv); EXPECT_TRUE(sock->IsConnected()); log.GetEntries(&entries); EXPECT_TRUE(LogContainsSSLConnectEndEvent(entries, -1)); @@ -584,43 +582,40 @@ TEST_F(SSLClientSocketTest, Connect) { } TEST_F(SSLClientSocketTest, ConnectExpired) { - net::SpawnedTestServer::SSLOptions ssl_options( - net::SpawnedTestServer::SSLOptions::CERT_EXPIRED); - net::SpawnedTestServer test_server(net::SpawnedTestServer::TYPE_HTTPS, - ssl_options, - base::FilePath()); + SpawnedTestServer::SSLOptions ssl_options( + SpawnedTestServer::SSLOptions::CERT_EXPIRED); + SpawnedTestServer test_server( + SpawnedTestServer::TYPE_HTTPS, ssl_options, base::FilePath()); ASSERT_TRUE(test_server.Start()); - cert_verifier_->set_default_result(net::ERR_CERT_DATE_INVALID); + cert_verifier_->set_default_result(ERR_CERT_DATE_INVALID); - net::AddressList addr; + AddressList addr; ASSERT_TRUE(test_server.GetAddressList(&addr)); - net::TestCompletionCallback callback; - net::CapturingNetLog log; - net::StreamSocket* transport = new net::TCPClientSocket( - addr, &log, net::NetLog::Source()); + TestCompletionCallback callback; + CapturingNetLog log; + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr, &log, NetLog::Source())); int rv = transport->Connect(callback.callback()); - if (rv == net::ERR_IO_PENDING) + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); - EXPECT_EQ(net::OK, rv); + EXPECT_EQ(OK, rv); - scoped_ptr<net::SSLClientSocket> sock( - CreateSSLClientSocket(transport, test_server.host_port_pair(), - kDefaultSSLConfig)); + scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( + transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); EXPECT_FALSE(sock->IsConnected()); rv = sock->Connect(callback.callback()); - net::CapturingNetLog::CapturedEntryList entries; + CapturingNetLog::CapturedEntryList entries; log.GetEntries(&entries); - EXPECT_TRUE(net::LogContainsBeginEvent( - entries, 5, net::NetLog::TYPE_SSL_CONNECT)); - if (rv == net::ERR_IO_PENDING) + EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT)); + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); - EXPECT_EQ(net::ERR_CERT_DATE_INVALID, rv); + EXPECT_EQ(ERR_CERT_DATE_INVALID, rv); // Rather than testing whether or not the underlying socket is connected, // test that the handshake has finished. This is because it may be @@ -631,43 +626,40 @@ TEST_F(SSLClientSocketTest, ConnectExpired) { } TEST_F(SSLClientSocketTest, ConnectMismatched) { - net::SpawnedTestServer::SSLOptions ssl_options( - net::SpawnedTestServer::SSLOptions::CERT_MISMATCHED_NAME); - net::SpawnedTestServer test_server(net::SpawnedTestServer::TYPE_HTTPS, - ssl_options, - base::FilePath()); + SpawnedTestServer::SSLOptions ssl_options( + SpawnedTestServer::SSLOptions::CERT_MISMATCHED_NAME); + SpawnedTestServer test_server( + SpawnedTestServer::TYPE_HTTPS, ssl_options, base::FilePath()); ASSERT_TRUE(test_server.Start()); - cert_verifier_->set_default_result(net::ERR_CERT_COMMON_NAME_INVALID); + cert_verifier_->set_default_result(ERR_CERT_COMMON_NAME_INVALID); - net::AddressList addr; + AddressList addr; ASSERT_TRUE(test_server.GetAddressList(&addr)); - net::TestCompletionCallback callback; - net::CapturingNetLog log; - net::StreamSocket* transport = new net::TCPClientSocket( - addr, &log, net::NetLog::Source()); + TestCompletionCallback callback; + CapturingNetLog log; + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr, &log, NetLog::Source())); int rv = transport->Connect(callback.callback()); - if (rv == net::ERR_IO_PENDING) + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); - EXPECT_EQ(net::OK, rv); + EXPECT_EQ(OK, rv); - scoped_ptr<net::SSLClientSocket> sock( - CreateSSLClientSocket(transport, test_server.host_port_pair(), - kDefaultSSLConfig)); + scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( + transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); EXPECT_FALSE(sock->IsConnected()); rv = sock->Connect(callback.callback()); - net::CapturingNetLog::CapturedEntryList entries; + CapturingNetLog::CapturedEntryList entries; log.GetEntries(&entries); - EXPECT_TRUE(net::LogContainsBeginEvent( - entries, 5, net::NetLog::TYPE_SSL_CONNECT)); - if (rv == net::ERR_IO_PENDING) + EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT)); + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); - EXPECT_EQ(net::ERR_CERT_COMMON_NAME_INVALID, rv); + EXPECT_EQ(ERR_CERT_COMMON_NAME_INVALID, rv); // Rather than testing whether or not the underlying socket is connected, // test that the handshake has finished. This is because it may be @@ -680,38 +672,35 @@ TEST_F(SSLClientSocketTest, ConnectMismatched) { // Attempt to connect to a page which requests a client certificate. It should // return an error code on connect. TEST_F(SSLClientSocketTest, ConnectClientAuthCertRequested) { - net::SpawnedTestServer::SSLOptions ssl_options; + SpawnedTestServer::SSLOptions ssl_options; ssl_options.request_client_certificate = true; - net::SpawnedTestServer test_server(net::SpawnedTestServer::TYPE_HTTPS, - ssl_options, - base::FilePath()); + SpawnedTestServer test_server( + SpawnedTestServer::TYPE_HTTPS, ssl_options, base::FilePath()); ASSERT_TRUE(test_server.Start()); - net::AddressList addr; + AddressList addr; ASSERT_TRUE(test_server.GetAddressList(&addr)); - net::TestCompletionCallback callback; - net::CapturingNetLog log; - net::StreamSocket* transport = new net::TCPClientSocket( - addr, &log, net::NetLog::Source()); + TestCompletionCallback callback; + CapturingNetLog log; + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr, &log, NetLog::Source())); int rv = transport->Connect(callback.callback()); - if (rv == net::ERR_IO_PENDING) + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); - EXPECT_EQ(net::OK, rv); + EXPECT_EQ(OK, rv); - scoped_ptr<net::SSLClientSocket> sock( - CreateSSLClientSocket(transport, test_server.host_port_pair(), - kDefaultSSLConfig)); + scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( + transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); EXPECT_FALSE(sock->IsConnected()); rv = sock->Connect(callback.callback()); - net::CapturingNetLog::CapturedEntryList entries; + CapturingNetLog::CapturedEntryList entries; log.GetEntries(&entries); - EXPECT_TRUE(net::LogContainsBeginEvent( - entries, 5, net::NetLog::TYPE_SSL_CONNECT)); - if (rv == net::ERR_IO_PENDING) + EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT)); + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); log.GetEntries(&entries); @@ -731,9 +720,9 @@ TEST_F(SSLClientSocketTest, ConnectClientAuthCertRequested) { // certificate. This test may still be useful as we'll want to close // the socket on a timeout if the user takes a long time to pick a // cert. Related bug: https://bugzilla.mozilla.org/show_bug.cgi?id=542832 - net::ExpectLogContainsSomewhere( - entries, 0, net::NetLog::TYPE_SSL_CONNECT, net::NetLog::PHASE_END); - EXPECT_EQ(net::ERR_SSL_CLIENT_AUTH_CERT_NEEDED, rv); + ExpectLogContainsSomewhere( + entries, 0, NetLog::TYPE_SSL_CONNECT, NetLog::PHASE_END); + EXPECT_EQ(ERR_SSL_CLIENT_AUTH_CERT_NEEDED, rv); EXPECT_FALSE(sock->IsConnected()); } @@ -742,32 +731,30 @@ TEST_F(SSLClientSocketTest, ConnectClientAuthCertRequested) { // // TODO(davidben): Also test providing an actual certificate. TEST_F(SSLClientSocketTest, ConnectClientAuthSendNullCert) { - net::SpawnedTestServer::SSLOptions ssl_options; + SpawnedTestServer::SSLOptions ssl_options; ssl_options.request_client_certificate = true; - net::SpawnedTestServer test_server(net::SpawnedTestServer::TYPE_HTTPS, - ssl_options, - base::FilePath()); + SpawnedTestServer test_server( + SpawnedTestServer::TYPE_HTTPS, ssl_options, base::FilePath()); ASSERT_TRUE(test_server.Start()); - net::AddressList addr; + AddressList addr; ASSERT_TRUE(test_server.GetAddressList(&addr)); - net::TestCompletionCallback callback; - net::CapturingNetLog log; - net::StreamSocket* transport = new net::TCPClientSocket( - addr, &log, net::NetLog::Source()); + TestCompletionCallback callback; + CapturingNetLog log; + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr, &log, NetLog::Source())); int rv = transport->Connect(callback.callback()); - if (rv == net::ERR_IO_PENDING) + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); - EXPECT_EQ(net::OK, rv); + EXPECT_EQ(OK, rv); - net::SSLConfig ssl_config = kDefaultSSLConfig; + SSLConfig ssl_config = kDefaultSSLConfig; ssl_config.send_client_cert = true; ssl_config.client_cert = NULL; - scoped_ptr<net::SSLClientSocket> sock( - CreateSSLClientSocket(transport, test_server.host_port_pair(), - ssl_config)); + scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( + transport.Pass(), test_server.host_port_pair(), ssl_config)); EXPECT_FALSE(sock->IsConnected()); @@ -775,14 +762,13 @@ TEST_F(SSLClientSocketTest, ConnectClientAuthSendNullCert) { // TODO(davidben): Add a test which requires them and verify the error. rv = sock->Connect(callback.callback()); - net::CapturingNetLog::CapturedEntryList entries; + CapturingNetLog::CapturedEntryList entries; log.GetEntries(&entries); - EXPECT_TRUE(net::LogContainsBeginEvent( - entries, 5, net::NetLog::TYPE_SSL_CONNECT)); - if (rv == net::ERR_IO_PENDING) + EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT)); + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); - EXPECT_EQ(net::OK, rv); + EXPECT_EQ(OK, rv); EXPECT_TRUE(sock->IsConnected()); log.GetEntries(&entries); EXPECT_TRUE(LogContainsSSLConnectEndEvent(entries, -1)); @@ -790,7 +776,7 @@ TEST_F(SSLClientSocketTest, ConnectClientAuthSendNullCert) { // We responded to the server's certificate request with a Certificate // message with no client certificate in it. ssl_info.client_cert_sent // should be false in this case. - net::SSLInfo ssl_info; + SSLInfo ssl_info; sock->GetSSLInfo(&ssl_info); EXPECT_FALSE(ssl_info.client_cert_sent); @@ -804,51 +790,50 @@ TEST_F(SSLClientSocketTest, ConnectClientAuthSendNullCert) { // - Server sends data unexpectedly. TEST_F(SSLClientSocketTest, Read) { - net::SpawnedTestServer test_server(net::SpawnedTestServer::TYPE_HTTPS, - net::SpawnedTestServer::kLocalhost, - base::FilePath()); + SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS, + SpawnedTestServer::kLocalhost, + base::FilePath()); ASSERT_TRUE(test_server.Start()); - net::AddressList addr; + AddressList addr; ASSERT_TRUE(test_server.GetAddressList(&addr)); - net::TestCompletionCallback callback; - net::StreamSocket* transport = new net::TCPClientSocket( - addr, NULL, net::NetLog::Source()); + TestCompletionCallback callback; + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr, NULL, NetLog::Source())); int rv = transport->Connect(callback.callback()); - if (rv == net::ERR_IO_PENDING) + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); - EXPECT_EQ(net::OK, rv); + EXPECT_EQ(OK, rv); - scoped_ptr<net::SSLClientSocket> sock( - CreateSSLClientSocket(transport, test_server.host_port_pair(), - kDefaultSSLConfig)); + scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( + transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); rv = sock->Connect(callback.callback()); - if (rv == net::ERR_IO_PENDING) + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); - EXPECT_EQ(net::OK, rv); + EXPECT_EQ(OK, rv); EXPECT_TRUE(sock->IsConnected()); const char request_text[] = "GET / HTTP/1.0\r\n\r\n"; - scoped_refptr<net::IOBuffer> request_buffer( - new net::IOBuffer(arraysize(request_text) - 1)); + scoped_refptr<IOBuffer> request_buffer( + new IOBuffer(arraysize(request_text) - 1)); memcpy(request_buffer->data(), request_text, arraysize(request_text) - 1); rv = sock->Write( request_buffer.get(), arraysize(request_text) - 1, callback.callback()); - EXPECT_TRUE(rv >= 0 || rv == net::ERR_IO_PENDING); + EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING); - if (rv == net::ERR_IO_PENDING) + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); EXPECT_EQ(static_cast<int>(arraysize(request_text) - 1), rv); - scoped_refptr<net::IOBuffer> buf(new net::IOBuffer(4096)); + scoped_refptr<IOBuffer> buf(new IOBuffer(4096)); for (;;) { rv = sock->Read(buf.get(), 4096, callback.callback()); - EXPECT_TRUE(rv >= 0 || rv == net::ERR_IO_PENDING); + EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING); - if (rv == net::ERR_IO_PENDING) + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); EXPECT_GE(rv, 0); @@ -862,39 +847,40 @@ TEST_F(SSLClientSocketTest, Read) { // the socket connection uncleanly. // This is a regression test for http://crbug.com/238536 TEST_F(SSLClientSocketTest, Read_WithSynchronousError) { - net::SpawnedTestServer test_server(net::SpawnedTestServer::TYPE_HTTPS, - net::SpawnedTestServer::kLocalhost, - base::FilePath()); + SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS, + SpawnedTestServer::kLocalhost, + base::FilePath()); ASSERT_TRUE(test_server.Start()); - net::AddressList addr; + AddressList addr; ASSERT_TRUE(test_server.GetAddressList(&addr)); - net::TestCompletionCallback callback; - scoped_ptr<net::StreamSocket> real_transport(new net::TCPClientSocket( - addr, NULL, net::NetLog::Source())); - SynchronousErrorStreamSocket* transport = new SynchronousErrorStreamSocket( - real_transport.Pass()); + TestCompletionCallback callback; + scoped_ptr<StreamSocket> real_transport( + new TCPClientSocket(addr, NULL, NetLog::Source())); + scoped_ptr<SynchronousErrorStreamSocket> transport( + new SynchronousErrorStreamSocket(real_transport.Pass())); int rv = callback.GetResult(transport->Connect(callback.callback())); - EXPECT_EQ(net::OK, rv); + EXPECT_EQ(OK, rv); // Disable TLS False Start to avoid handshake non-determinism. - net::SSLConfig ssl_config; + SSLConfig ssl_config; ssl_config.false_start_enabled = false; - scoped_ptr<net::SSLClientSocket> sock( - CreateSSLClientSocket(transport, test_server.host_port_pair(), + SynchronousErrorStreamSocket* raw_transport = transport.get(); + scoped_ptr<SSLClientSocket> sock( + CreateSSLClientSocket(transport.PassAs<StreamSocket>(), + test_server.host_port_pair(), ssl_config)); rv = callback.GetResult(sock->Connect(callback.callback())); - EXPECT_EQ(net::OK, rv); + EXPECT_EQ(OK, rv); EXPECT_TRUE(sock->IsConnected()); const char request_text[] = "GET / HTTP/1.0\r\n\r\n"; static const int kRequestTextSize = static_cast<int>(arraysize(request_text) - 1); - scoped_refptr<net::IOBuffer> request_buffer( - new net::IOBuffer(kRequestTextSize)); + scoped_refptr<IOBuffer> request_buffer(new IOBuffer(kRequestTextSize)); memcpy(request_buffer->data(), request_text, kRequestTextSize); rv = callback.GetResult( @@ -902,9 +888,9 @@ TEST_F(SSLClientSocketTest, Read_WithSynchronousError) { EXPECT_EQ(kRequestTextSize, rv); // Simulate an unclean/forcible shutdown. - transport->SetNextReadError(net::ERR_CONNECTION_RESET); + raw_transport->SetNextReadError(ERR_CONNECTION_RESET); - scoped_refptr<net::IOBuffer> buf(new net::IOBuffer(4096)); + scoped_refptr<IOBuffer> buf(new IOBuffer(4096)); // Note: This test will hang if this bug has regressed. Simply checking that // rv != ERR_IO_PENDING is insufficient, as ERR_IO_PENDING is a legitimate @@ -913,7 +899,7 @@ TEST_F(SSLClientSocketTest, Read_WithSynchronousError) { #if !defined(USE_OPENSSL) // SSLClientSocketNSS records the error exactly - EXPECT_EQ(net::ERR_CONNECTION_RESET, rv); + EXPECT_EQ(ERR_CONNECTION_RESET, rv); #else // SSLClientSocketOpenSSL treats any errors as a simple EOF. EXPECT_EQ(0, rv); @@ -925,49 +911,51 @@ TEST_F(SSLClientSocketTest, Read_WithSynchronousError) { // intermediary terminates the socket connection uncleanly. // This is a regression test for http://crbug.com/249848 TEST_F(SSLClientSocketTest, Write_WithSynchronousError) { - net::SpawnedTestServer test_server(net::SpawnedTestServer::TYPE_HTTPS, - net::SpawnedTestServer::kLocalhost, - base::FilePath()); + SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS, + SpawnedTestServer::kLocalhost, + base::FilePath()); ASSERT_TRUE(test_server.Start()); - net::AddressList addr; + AddressList addr; ASSERT_TRUE(test_server.GetAddressList(&addr)); - net::TestCompletionCallback callback; - scoped_ptr<net::StreamSocket> real_transport(new net::TCPClientSocket( - addr, NULL, net::NetLog::Source())); - // Note: |error_socket|'s ownership is handed to |transport|, but the pointer + TestCompletionCallback callback; + scoped_ptr<StreamSocket> real_transport( + new TCPClientSocket(addr, NULL, NetLog::Source())); + // Note: |error_socket|'s ownership is handed to |transport|, but a pointer // is retained in order to configure additional errors. - SynchronousErrorStreamSocket* error_socket = new SynchronousErrorStreamSocket( - real_transport.Pass()); - FakeBlockingStreamSocket* transport = new FakeBlockingStreamSocket( - scoped_ptr<net::StreamSocket>(error_socket)); + scoped_ptr<SynchronousErrorStreamSocket> error_socket( + new SynchronousErrorStreamSocket(real_transport.Pass())); + SynchronousErrorStreamSocket* raw_error_socket = error_socket.get(); + scoped_ptr<FakeBlockingStreamSocket> transport( + new FakeBlockingStreamSocket(error_socket.PassAs<StreamSocket>())); + FakeBlockingStreamSocket* raw_transport = transport.get(); int rv = callback.GetResult(transport->Connect(callback.callback())); - EXPECT_EQ(net::OK, rv); + EXPECT_EQ(OK, rv); // Disable TLS False Start to avoid handshake non-determinism. - net::SSLConfig ssl_config; + SSLConfig ssl_config; ssl_config.false_start_enabled = false; - scoped_ptr<net::SSLClientSocket> sock( - CreateSSLClientSocket(transport, test_server.host_port_pair(), + scoped_ptr<SSLClientSocket> sock( + CreateSSLClientSocket(transport.PassAs<StreamSocket>(), + test_server.host_port_pair(), ssl_config)); rv = callback.GetResult(sock->Connect(callback.callback())); - EXPECT_EQ(net::OK, rv); + EXPECT_EQ(OK, rv); EXPECT_TRUE(sock->IsConnected()); const char request_text[] = "GET / HTTP/1.0\r\n\r\n"; static const int kRequestTextSize = static_cast<int>(arraysize(request_text) - 1); - scoped_refptr<net::IOBuffer> request_buffer( - new net::IOBuffer(kRequestTextSize)); + scoped_refptr<IOBuffer> request_buffer(new IOBuffer(kRequestTextSize)); memcpy(request_buffer->data(), request_text, kRequestTextSize); // Simulate an unclean/forcible shutdown on the underlying socket. // However, simulate this error asynchronously. - error_socket->SetNextWriteError(net::ERR_CONNECTION_RESET); - transport->SetNextWriteShouldBlock(); + raw_error_socket->SetNextWriteError(ERR_CONNECTION_RESET); + raw_transport->SetNextWriteShouldBlock(); // This write should complete synchronously, because the TLS ciphertext // can be created and placed into the outgoing buffers independent of the @@ -976,14 +964,14 @@ TEST_F(SSLClientSocketTest, Write_WithSynchronousError) { sock->Write(request_buffer.get(), kRequestTextSize, callback.callback())); EXPECT_EQ(kRequestTextSize, rv); - scoped_refptr<net::IOBuffer> buf(new net::IOBuffer(4096)); + scoped_refptr<IOBuffer> buf(new IOBuffer(4096)); rv = sock->Read(buf.get(), 4096, callback.callback()); - EXPECT_EQ(net::ERR_IO_PENDING, rv); + EXPECT_EQ(ERR_IO_PENDING, rv); // Now unblock the outgoing request, having it fail with the connection // being reset. - transport->UnblockWrite(); + raw_transport->UnblockWrite(); // Note: This will cause an inifite loop if this bug has regressed. Simply // checking that rv != ERR_IO_PENDING is insufficient, as ERR_IO_PENDING @@ -992,7 +980,7 @@ TEST_F(SSLClientSocketTest, Write_WithSynchronousError) { #if !defined(USE_OPENSSL) // SSLClientSocketNSS records the error exactly - EXPECT_EQ(net::ERR_CONNECTION_RESET, rv); + EXPECT_EQ(ERR_CONNECTION_RESET, rv); #else // SSLClientSocketOpenSSL treats any errors as a simple EOF. EXPECT_EQ(0, rv); @@ -1002,38 +990,37 @@ TEST_F(SSLClientSocketTest, Write_WithSynchronousError) { // Test the full duplex mode, with Read and Write pending at the same time. // This test also serves as a regression test for http://crbug.com/29815. TEST_F(SSLClientSocketTest, Read_FullDuplex) { - net::SpawnedTestServer test_server(net::SpawnedTestServer::TYPE_HTTPS, - net::SpawnedTestServer::kLocalhost, - base::FilePath()); + SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS, + SpawnedTestServer::kLocalhost, + base::FilePath()); ASSERT_TRUE(test_server.Start()); - net::AddressList addr; + AddressList addr; ASSERT_TRUE(test_server.GetAddressList(&addr)); - net::TestCompletionCallback callback; // Used for everything except Write. + TestCompletionCallback callback; // Used for everything except Write. - net::StreamSocket* transport = new net::TCPClientSocket( - addr, NULL, net::NetLog::Source()); + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr, NULL, NetLog::Source())); int rv = transport->Connect(callback.callback()); - if (rv == net::ERR_IO_PENDING) + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); - EXPECT_EQ(net::OK, rv); + EXPECT_EQ(OK, rv); - scoped_ptr<net::SSLClientSocket> sock( - CreateSSLClientSocket(transport, test_server.host_port_pair(), - kDefaultSSLConfig)); + scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( + transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); rv = sock->Connect(callback.callback()); - if (rv == net::ERR_IO_PENDING) + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); - EXPECT_EQ(net::OK, rv); + EXPECT_EQ(OK, rv); EXPECT_TRUE(sock->IsConnected()); // Issue a "hanging" Read first. - scoped_refptr<net::IOBuffer> buf(new net::IOBuffer(4096)); + scoped_refptr<IOBuffer> buf(new IOBuffer(4096)); rv = sock->Read(buf.get(), 4096, callback.callback()); // We haven't written the request, so there should be no response yet. - ASSERT_EQ(net::ERR_IO_PENDING, rv); + ASSERT_EQ(ERR_IO_PENDING, rv); // Write the request. // The request is padded with a User-Agent header to a size that causes the @@ -1043,15 +1030,14 @@ TEST_F(SSLClientSocketTest, Read_FullDuplex) { for (int i = 0; i < 3770; ++i) request_text.push_back('*'); request_text.append("\r\n\r\n"); - scoped_refptr<net::IOBuffer> request_buffer( - new net::StringIOBuffer(request_text)); + scoped_refptr<IOBuffer> request_buffer(new StringIOBuffer(request_text)); - net::TestCompletionCallback callback2; // Used for Write only. + TestCompletionCallback callback2; // Used for Write only. rv = sock->Write( request_buffer.get(), request_text.size(), callback2.callback()); - EXPECT_TRUE(rv >= 0 || rv == net::ERR_IO_PENDING); + EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING); - if (rv == net::ERR_IO_PENDING) + if (rv == ERR_IO_PENDING) rv = callback2.WaitForResult(); EXPECT_EQ(static_cast<int>(request_text.size()), rv); @@ -1067,62 +1053,65 @@ TEST_F(SSLClientSocketTest, Read_FullDuplex) { // callback, the Write() callback should not be invoked. // Regression test for http://crbug.com/232633 TEST_F(SSLClientSocketTest, Read_DeleteWhilePendingFullDuplex) { - net::SpawnedTestServer test_server(net::SpawnedTestServer::TYPE_HTTPS, - net::SpawnedTestServer::kLocalhost, - base::FilePath()); + SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS, + SpawnedTestServer::kLocalhost, + base::FilePath()); ASSERT_TRUE(test_server.Start()); - net::AddressList addr; + AddressList addr; ASSERT_TRUE(test_server.GetAddressList(&addr)); - net::TestCompletionCallback callback; - scoped_ptr<net::StreamSocket> real_transport(new net::TCPClientSocket( - addr, NULL, net::NetLog::Source())); - // Note: |error_socket|'s ownership is handed to |transport|, but the pointer + TestCompletionCallback callback; + scoped_ptr<StreamSocket> real_transport( + new TCPClientSocket(addr, NULL, NetLog::Source())); + // Note: |error_socket|'s ownership is handed to |transport|, but a pointer // is retained in order to configure additional errors. - SynchronousErrorStreamSocket* error_socket = new SynchronousErrorStreamSocket( - real_transport.Pass()); - FakeBlockingStreamSocket* transport = new FakeBlockingStreamSocket( - scoped_ptr<net::StreamSocket>(error_socket)); + scoped_ptr<SynchronousErrorStreamSocket> error_socket( + new SynchronousErrorStreamSocket(real_transport.Pass())); + SynchronousErrorStreamSocket* raw_error_socket = error_socket.get(); + scoped_ptr<FakeBlockingStreamSocket> transport( + new FakeBlockingStreamSocket(error_socket.PassAs<StreamSocket>())); + FakeBlockingStreamSocket* raw_transport = transport.get(); int rv = callback.GetResult(transport->Connect(callback.callback())); - EXPECT_EQ(net::OK, rv); + EXPECT_EQ(OK, rv); // Disable TLS False Start to avoid handshake non-determinism. - net::SSLConfig ssl_config; + SSLConfig ssl_config; ssl_config.false_start_enabled = false; - net::SSLClientSocket* sock( - CreateSSLClientSocket(transport, test_server.host_port_pair(), - ssl_config)); + scoped_ptr<SSLClientSocket> sock = + CreateSSLClientSocket(transport.PassAs<StreamSocket>(), + test_server.host_port_pair(), + ssl_config); rv = callback.GetResult(sock->Connect(callback.callback())); - EXPECT_EQ(net::OK, rv); + EXPECT_EQ(OK, rv); EXPECT_TRUE(sock->IsConnected()); std::string request_text = "GET / HTTP/1.1\r\nUser-Agent: long browser name "; request_text.append(20 * 1024, '*'); request_text.append("\r\n\r\n"); - scoped_refptr<net::DrainableIOBuffer> request_buffer( - new net::DrainableIOBuffer(new net::StringIOBuffer(request_text), - request_text.size())); + scoped_refptr<DrainableIOBuffer> request_buffer(new DrainableIOBuffer( + new StringIOBuffer(request_text), request_text.size())); // Simulate errors being returned from the underlying Read() and Write() ... - error_socket->SetNextReadError(net::ERR_CONNECTION_RESET); - error_socket->SetNextWriteError(net::ERR_CONNECTION_RESET); + raw_error_socket->SetNextReadError(ERR_CONNECTION_RESET); + raw_error_socket->SetNextWriteError(ERR_CONNECTION_RESET); // ... but have those errors returned asynchronously. Because the Write() will // return first, this will trigger the error. - transport->SetNextReadShouldBlock(); - transport->SetNextWriteShouldBlock(); + raw_transport->SetNextReadShouldBlock(); + raw_transport->SetNextWriteShouldBlock(); // Enqueue a Read() before calling Write(), which should "hang" due to // the ERR_IO_PENDING caused by SetReadShouldBlock() and thus return. - DeleteSocketCallback read_callback(sock); - scoped_refptr<net::IOBuffer> read_buf(new net::IOBuffer(4096)); - rv = sock->Read(read_buf.get(), 4096, read_callback.callback()); + SSLClientSocket* raw_sock = sock.get(); + DeleteSocketCallback read_callback(sock.release()); + scoped_refptr<IOBuffer> read_buf(new IOBuffer(4096)); + rv = raw_sock->Read(read_buf.get(), 4096, read_callback.callback()); // Ensure things didn't complete synchronously, otherwise |sock| is invalid. - ASSERT_EQ(net::ERR_IO_PENDING, rv); + ASSERT_EQ(ERR_IO_PENDING, rv); ASSERT_FALSE(read_callback.have_result()); #if !defined(USE_OPENSSL) @@ -1142,9 +1131,9 @@ TEST_F(SSLClientSocketTest, Read_DeleteWhilePendingFullDuplex) { // SSLClientSocketOpenSSL::Write() will not return until all of // |request_buffer| has been written to the underlying BIO (although not // necessarily the underlying transport). - rv = callback.GetResult(sock->Write(request_buffer.get(), - request_buffer->BytesRemaining(), - callback.callback())); + rv = callback.GetResult(raw_sock->Write(request_buffer.get(), + request_buffer->BytesRemaining(), + callback.callback())); ASSERT_LT(0, rv); request_buffer->DidConsume(rv); @@ -1157,22 +1146,22 @@ TEST_F(SSLClientSocketTest, Read_DeleteWhilePendingFullDuplex) { // Attempt to write the remaining data. NSS will not be able to consume the // application data because the internal buffers are full, while OpenSSL will // return that its blocked because the underlying transport is blocked. - rv = sock->Write(request_buffer.get(), - request_buffer->BytesRemaining(), - callback.callback()); - ASSERT_EQ(net::ERR_IO_PENDING, rv); + rv = raw_sock->Write(request_buffer.get(), + request_buffer->BytesRemaining(), + callback.callback()); + ASSERT_EQ(ERR_IO_PENDING, rv); ASSERT_FALSE(callback.have_result()); // Now unblock Write(), which will invoke OnSendComplete and (eventually) // call the Read() callback, deleting the socket and thus aborting calling // the Write() callback. - transport->UnblockWrite(); + raw_transport->UnblockWrite(); rv = read_callback.WaitForResult(); #if !defined(USE_OPENSSL) // NSS records the error exactly. - EXPECT_EQ(net::ERR_CONNECTION_RESET, rv); + EXPECT_EQ(ERR_CONNECTION_RESET, rv); #else // OpenSSL treats any errors as a simple EOF. EXPECT_EQ(0, rv); @@ -1183,50 +1172,49 @@ TEST_F(SSLClientSocketTest, Read_DeleteWhilePendingFullDuplex) { } TEST_F(SSLClientSocketTest, Read_SmallChunks) { - net::SpawnedTestServer test_server(net::SpawnedTestServer::TYPE_HTTPS, - net::SpawnedTestServer::kLocalhost, - base::FilePath()); + SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS, + SpawnedTestServer::kLocalhost, + base::FilePath()); ASSERT_TRUE(test_server.Start()); - net::AddressList addr; + AddressList addr; ASSERT_TRUE(test_server.GetAddressList(&addr)); - net::TestCompletionCallback callback; - net::StreamSocket* transport = new net::TCPClientSocket( - addr, NULL, net::NetLog::Source()); + TestCompletionCallback callback; + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr, NULL, NetLog::Source())); int rv = transport->Connect(callback.callback()); - if (rv == net::ERR_IO_PENDING) + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); - EXPECT_EQ(net::OK, rv); + EXPECT_EQ(OK, rv); - scoped_ptr<net::SSLClientSocket> sock( - CreateSSLClientSocket(transport, test_server.host_port_pair(), - kDefaultSSLConfig)); + scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( + transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); rv = sock->Connect(callback.callback()); - if (rv == net::ERR_IO_PENDING) + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); - EXPECT_EQ(net::OK, rv); + EXPECT_EQ(OK, rv); const char request_text[] = "GET / HTTP/1.0\r\n\r\n"; - scoped_refptr<net::IOBuffer> request_buffer( - new net::IOBuffer(arraysize(request_text) - 1)); + scoped_refptr<IOBuffer> request_buffer( + new IOBuffer(arraysize(request_text) - 1)); memcpy(request_buffer->data(), request_text, arraysize(request_text) - 1); rv = sock->Write( request_buffer.get(), arraysize(request_text) - 1, callback.callback()); - EXPECT_TRUE(rv >= 0 || rv == net::ERR_IO_PENDING); + EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING); - if (rv == net::ERR_IO_PENDING) + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); EXPECT_EQ(static_cast<int>(arraysize(request_text) - 1), rv); - scoped_refptr<net::IOBuffer> buf(new net::IOBuffer(1)); + scoped_refptr<IOBuffer> buf(new IOBuffer(1)); for (;;) { rv = sock->Read(buf.get(), 1, callback.callback()); - EXPECT_TRUE(rv >= 0 || rv == net::ERR_IO_PENDING); + EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING); - if (rv == net::ERR_IO_PENDING) + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); EXPECT_GE(rv, 0); @@ -1236,34 +1224,36 @@ TEST_F(SSLClientSocketTest, Read_SmallChunks) { } TEST_F(SSLClientSocketTest, Read_ManySmallRecords) { - net::SpawnedTestServer test_server(net::SpawnedTestServer::TYPE_HTTPS, - net::SpawnedTestServer::kLocalhost, - base::FilePath()); + SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS, + SpawnedTestServer::kLocalhost, + base::FilePath()); ASSERT_TRUE(test_server.Start()); - net::AddressList addr; + AddressList addr; ASSERT_TRUE(test_server.GetAddressList(&addr)); - net::TestCompletionCallback callback; + TestCompletionCallback callback; - scoped_ptr<net::StreamSocket> real_transport(new net::TCPClientSocket( - addr, NULL, net::NetLog::Source())); - ReadBufferingStreamSocket* transport = new ReadBufferingStreamSocket( - real_transport.Pass()); + scoped_ptr<StreamSocket> real_transport( + new TCPClientSocket(addr, NULL, NetLog::Source())); + scoped_ptr<ReadBufferingStreamSocket> transport( + new ReadBufferingStreamSocket(real_transport.Pass())); + ReadBufferingStreamSocket* raw_transport = transport.get(); int rv = callback.GetResult(transport->Connect(callback.callback())); - ASSERT_EQ(net::OK, rv); + ASSERT_EQ(OK, rv); - scoped_ptr<net::SSLClientSocket> sock( - CreateSSLClientSocket(transport, test_server.host_port_pair(), + scoped_ptr<SSLClientSocket> sock( + CreateSSLClientSocket(transport.PassAs<StreamSocket>(), + test_server.host_port_pair(), kDefaultSSLConfig)); rv = callback.GetResult(sock->Connect(callback.callback())); - ASSERT_EQ(net::OK, rv); + ASSERT_EQ(OK, rv); ASSERT_TRUE(sock->IsConnected()); const char request_text[] = "GET /ssl-many-small-records HTTP/1.0\r\n\r\n"; - scoped_refptr<net::IOBuffer> request_buffer( - new net::IOBuffer(arraysize(request_text) - 1)); + scoped_refptr<IOBuffer> request_buffer( + new IOBuffer(arraysize(request_text) - 1)); memcpy(request_buffer->data(), request_text, arraysize(request_text) - 1); rv = callback.GetResult(sock->Write( @@ -1280,117 +1270,114 @@ TEST_F(SSLClientSocketTest, Read_ManySmallRecords) { // 15K was chosen because 15K is smaller than the 17K (max) read issued by // the SSLClientSocket implementation, and larger than the minimum amount // of ciphertext necessary to contain the 8K of plaintext requested below. - transport->SetBufferSize(15000); + raw_transport->SetBufferSize(15000); - scoped_refptr<net::IOBuffer> buffer(new net::IOBuffer(8192)); + scoped_refptr<IOBuffer> buffer(new IOBuffer(8192)); rv = callback.GetResult(sock->Read(buffer.get(), 8192, callback.callback())); ASSERT_EQ(rv, 8192); } TEST_F(SSLClientSocketTest, Read_Interrupted) { - net::SpawnedTestServer test_server(net::SpawnedTestServer::TYPE_HTTPS, - net::SpawnedTestServer::kLocalhost, - base::FilePath()); + SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS, + SpawnedTestServer::kLocalhost, + base::FilePath()); ASSERT_TRUE(test_server.Start()); - net::AddressList addr; + AddressList addr; ASSERT_TRUE(test_server.GetAddressList(&addr)); - net::TestCompletionCallback callback; - net::StreamSocket* transport = new net::TCPClientSocket( - addr, NULL, net::NetLog::Source()); + TestCompletionCallback callback; + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr, NULL, NetLog::Source())); int rv = transport->Connect(callback.callback()); - if (rv == net::ERR_IO_PENDING) + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); - EXPECT_EQ(net::OK, rv); + EXPECT_EQ(OK, rv); - scoped_ptr<net::SSLClientSocket> sock( - CreateSSLClientSocket(transport, test_server.host_port_pair(), - kDefaultSSLConfig)); + scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( + transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); rv = sock->Connect(callback.callback()); - if (rv == net::ERR_IO_PENDING) + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); - EXPECT_EQ(net::OK, rv); + EXPECT_EQ(OK, rv); const char request_text[] = "GET / HTTP/1.0\r\n\r\n"; - scoped_refptr<net::IOBuffer> request_buffer( - new net::IOBuffer(arraysize(request_text) - 1)); + scoped_refptr<IOBuffer> request_buffer( + new IOBuffer(arraysize(request_text) - 1)); memcpy(request_buffer->data(), request_text, arraysize(request_text) - 1); rv = sock->Write( request_buffer.get(), arraysize(request_text) - 1, callback.callback()); - EXPECT_TRUE(rv >= 0 || rv == net::ERR_IO_PENDING); + EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING); - if (rv == net::ERR_IO_PENDING) + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); EXPECT_EQ(static_cast<int>(arraysize(request_text) - 1), rv); // Do a partial read and then exit. This test should not crash! - scoped_refptr<net::IOBuffer> buf(new net::IOBuffer(512)); + scoped_refptr<IOBuffer> buf(new IOBuffer(512)); rv = sock->Read(buf.get(), 512, callback.callback()); - EXPECT_TRUE(rv > 0 || rv == net::ERR_IO_PENDING); + EXPECT_TRUE(rv > 0 || rv == ERR_IO_PENDING); - if (rv == net::ERR_IO_PENDING) + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); EXPECT_GT(rv, 0); } TEST_F(SSLClientSocketTest, Read_FullLogging) { - net::SpawnedTestServer test_server(net::SpawnedTestServer::TYPE_HTTPS, - net::SpawnedTestServer::kLocalhost, - base::FilePath()); + SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS, + SpawnedTestServer::kLocalhost, + base::FilePath()); ASSERT_TRUE(test_server.Start()); - net::AddressList addr; + AddressList addr; ASSERT_TRUE(test_server.GetAddressList(&addr)); - net::TestCompletionCallback callback; - net::CapturingNetLog log; - log.SetLogLevel(net::NetLog::LOG_ALL); - net::StreamSocket* transport = new net::TCPClientSocket( - addr, &log, net::NetLog::Source()); + TestCompletionCallback callback; + CapturingNetLog log; + log.SetLogLevel(NetLog::LOG_ALL); + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr, &log, NetLog::Source())); int rv = transport->Connect(callback.callback()); - if (rv == net::ERR_IO_PENDING) + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); - EXPECT_EQ(net::OK, rv); + EXPECT_EQ(OK, rv); - scoped_ptr<net::SSLClientSocket> sock( - CreateSSLClientSocket(transport, test_server.host_port_pair(), - kDefaultSSLConfig)); + scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( + transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); rv = sock->Connect(callback.callback()); - if (rv == net::ERR_IO_PENDING) + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); - EXPECT_EQ(net::OK, rv); + EXPECT_EQ(OK, rv); EXPECT_TRUE(sock->IsConnected()); const char request_text[] = "GET / HTTP/1.0\r\n\r\n"; - scoped_refptr<net::IOBuffer> request_buffer( - new net::IOBuffer(arraysize(request_text) - 1)); + scoped_refptr<IOBuffer> request_buffer( + new IOBuffer(arraysize(request_text) - 1)); memcpy(request_buffer->data(), request_text, arraysize(request_text) - 1); rv = sock->Write( request_buffer.get(), arraysize(request_text) - 1, callback.callback()); - EXPECT_TRUE(rv >= 0 || rv == net::ERR_IO_PENDING); + EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING); - if (rv == net::ERR_IO_PENDING) + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); EXPECT_EQ(static_cast<int>(arraysize(request_text) - 1), rv); - net::CapturingNetLog::CapturedEntryList entries; + CapturingNetLog::CapturedEntryList entries; log.GetEntries(&entries); - size_t last_index = net::ExpectLogContainsSomewhereAfter( - entries, 5, net::NetLog::TYPE_SSL_SOCKET_BYTES_SENT, - net::NetLog::PHASE_NONE); + size_t last_index = ExpectLogContainsSomewhereAfter( + entries, 5, NetLog::TYPE_SSL_SOCKET_BYTES_SENT, NetLog::PHASE_NONE); - scoped_refptr<net::IOBuffer> buf(new net::IOBuffer(4096)); + scoped_refptr<IOBuffer> buf(new IOBuffer(4096)); for (;;) { rv = sock->Read(buf.get(), 4096, callback.callback()); - EXPECT_TRUE(rv >= 0 || rv == net::ERR_IO_PENDING); + EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING); - if (rv == net::ERR_IO_PENDING) + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); EXPECT_GE(rv, 0); @@ -1398,61 +1385,59 @@ TEST_F(SSLClientSocketTest, Read_FullLogging) { break; log.GetEntries(&entries); - last_index = net::ExpectLogContainsSomewhereAfter( - entries, last_index + 1, net::NetLog::TYPE_SSL_SOCKET_BYTES_RECEIVED, - net::NetLog::PHASE_NONE); + last_index = + ExpectLogContainsSomewhereAfter(entries, + last_index + 1, + NetLog::TYPE_SSL_SOCKET_BYTES_RECEIVED, + NetLog::PHASE_NONE); } } // Regression test for http://crbug.com/42538 TEST_F(SSLClientSocketTest, PrematureApplicationData) { - net::SpawnedTestServer test_server(net::SpawnedTestServer::TYPE_HTTPS, - net::SpawnedTestServer::kLocalhost, - base::FilePath()); + SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS, + SpawnedTestServer::kLocalhost, + base::FilePath()); ASSERT_TRUE(test_server.Start()); - net::AddressList addr; - net::TestCompletionCallback callback; + AddressList addr; + TestCompletionCallback callback; static const unsigned char application_data[] = { - 0x17, 0x03, 0x01, 0x00, 0x4a, 0x02, 0x00, 0x00, 0x46, 0x03, 0x01, 0x4b, - 0xc2, 0xf8, 0xb2, 0xc1, 0x56, 0x42, 0xb9, 0x57, 0x7f, 0xde, 0x87, 0x46, - 0xf7, 0xa3, 0x52, 0x42, 0x21, 0xf0, 0x13, 0x1c, 0x9c, 0x83, 0x88, 0xd6, - 0x93, 0x0c, 0xf6, 0x36, 0x30, 0x05, 0x7e, 0x20, 0xb5, 0xb5, 0x73, 0x36, - 0x53, 0x83, 0x0a, 0xfc, 0x17, 0x63, 0xbf, 0xa0, 0xe4, 0x42, 0x90, 0x0d, - 0x2f, 0x18, 0x6d, 0x20, 0xd8, 0x36, 0x3f, 0xfc, 0xe6, 0x01, 0xfa, 0x0f, - 0xa5, 0x75, 0x7f, 0x09, 0x00, 0x04, 0x00, 0x16, 0x03, 0x01, 0x11, 0x57, - 0x0b, 0x00, 0x11, 0x53, 0x00, 0x11, 0x50, 0x00, 0x06, 0x22, 0x30, 0x82, - 0x06, 0x1e, 0x30, 0x82, 0x05, 0x06, 0xa0, 0x03, 0x02, 0x01, 0x02, 0x02, - 0x0a - }; + 0x17, 0x03, 0x01, 0x00, 0x4a, 0x02, 0x00, 0x00, 0x46, 0x03, 0x01, 0x4b, + 0xc2, 0xf8, 0xb2, 0xc1, 0x56, 0x42, 0xb9, 0x57, 0x7f, 0xde, 0x87, 0x46, + 0xf7, 0xa3, 0x52, 0x42, 0x21, 0xf0, 0x13, 0x1c, 0x9c, 0x83, 0x88, 0xd6, + 0x93, 0x0c, 0xf6, 0x36, 0x30, 0x05, 0x7e, 0x20, 0xb5, 0xb5, 0x73, 0x36, + 0x53, 0x83, 0x0a, 0xfc, 0x17, 0x63, 0xbf, 0xa0, 0xe4, 0x42, 0x90, 0x0d, + 0x2f, 0x18, 0x6d, 0x20, 0xd8, 0x36, 0x3f, 0xfc, 0xe6, 0x01, 0xfa, 0x0f, + 0xa5, 0x75, 0x7f, 0x09, 0x00, 0x04, 0x00, 0x16, 0x03, 0x01, 0x11, 0x57, + 0x0b, 0x00, 0x11, 0x53, 0x00, 0x11, 0x50, 0x00, 0x06, 0x22, 0x30, 0x82, + 0x06, 0x1e, 0x30, 0x82, 0x05, 0x06, 0xa0, 0x03, 0x02, 0x01, 0x02, 0x02, + 0x0a}; // All reads and writes complete synchronously (async=false). - net::MockRead data_reads[] = { - net::MockRead(net::SYNCHRONOUS, - reinterpret_cast<const char*>(application_data), - arraysize(application_data)), - net::MockRead(net::SYNCHRONOUS, net::OK), - }; + MockRead data_reads[] = { + MockRead(SYNCHRONOUS, + reinterpret_cast<const char*>(application_data), + arraysize(application_data)), + MockRead(SYNCHRONOUS, OK), }; - net::StaticSocketDataProvider data(data_reads, arraysize(data_reads), - NULL, 0); + StaticSocketDataProvider data(data_reads, arraysize(data_reads), NULL, 0); - net::StreamSocket* transport = - new net::MockTCPClientSocket(addr, NULL, &data); + scoped_ptr<StreamSocket> transport( + new MockTCPClientSocket(addr, NULL, &data)); int rv = transport->Connect(callback.callback()); - if (rv == net::ERR_IO_PENDING) + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); - EXPECT_EQ(net::OK, rv); + EXPECT_EQ(OK, rv); - scoped_ptr<net::SSLClientSocket> sock( - CreateSSLClientSocket(transport, test_server.host_port_pair(), - kDefaultSSLConfig)); + scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( + transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); rv = sock->Connect(callback.callback()); - if (rv == net::ERR_IO_PENDING) + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); - EXPECT_EQ(net::ERR_SSL_PROTOCOL_ERROR, rv); + EXPECT_EQ(ERR_SSL_PROTOCOL_ERROR, rv); } TEST_F(SSLClientSocketTest, CipherSuiteDisables) { @@ -1460,46 +1445,41 @@ TEST_F(SSLClientSocketTest, CipherSuiteDisables) { // http://www.iana.org/assignments/tls-parameters/tls-parameters.xml, // only disabling those cipher suites that the test server actually // implements. - const uint16 kCiphersToDisable[] = { - 0x0005, // TLS_RSA_WITH_RC4_128_SHA + const uint16 kCiphersToDisable[] = {0x0005, // TLS_RSA_WITH_RC4_128_SHA }; - net::SpawnedTestServer::SSLOptions ssl_options; + SpawnedTestServer::SSLOptions ssl_options; // Enable only RC4 on the test server. - ssl_options.bulk_ciphers = - net::SpawnedTestServer::SSLOptions::BULK_CIPHER_RC4; - net::SpawnedTestServer test_server(net::SpawnedTestServer::TYPE_HTTPS, - ssl_options, - base::FilePath()); + ssl_options.bulk_ciphers = SpawnedTestServer::SSLOptions::BULK_CIPHER_RC4; + SpawnedTestServer test_server( + SpawnedTestServer::TYPE_HTTPS, ssl_options, base::FilePath()); ASSERT_TRUE(test_server.Start()); - net::AddressList addr; + AddressList addr; ASSERT_TRUE(test_server.GetAddressList(&addr)); - net::TestCompletionCallback callback; - net::CapturingNetLog log; - net::StreamSocket* transport = new net::TCPClientSocket( - addr, &log, net::NetLog::Source()); + TestCompletionCallback callback; + CapturingNetLog log; + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr, &log, NetLog::Source())); int rv = transport->Connect(callback.callback()); - if (rv == net::ERR_IO_PENDING) + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); - EXPECT_EQ(net::OK, rv); + EXPECT_EQ(OK, rv); - net::SSLConfig ssl_config; + SSLConfig ssl_config; for (size_t i = 0; i < arraysize(kCiphersToDisable); ++i) ssl_config.disabled_cipher_suites.push_back(kCiphersToDisable[i]); - scoped_ptr<net::SSLClientSocket> sock( - CreateSSLClientSocket(transport, test_server.host_port_pair(), - ssl_config)); + scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( + transport.Pass(), test_server.host_port_pair(), ssl_config)); EXPECT_FALSE(sock->IsConnected()); rv = sock->Connect(callback.callback()); - net::CapturingNetLog::CapturedEntryList entries; + CapturingNetLog::CapturedEntryList entries; log.GetEntries(&entries); - EXPECT_TRUE(net::LogContainsBeginEvent( - entries, 5, net::NetLog::TYPE_SSL_CONNECT)); + EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT)); // NSS has special handling that maps a handshake_failure alert received // immediately after a client_hello to be a mismatched cipher suite error, @@ -1507,17 +1487,16 @@ TEST_F(SSLClientSocketTest, CipherSuiteDisables) { // Secure Transport (OS X), the handshake_failure is bubbled up without any // interpretation, leading to ERR_SSL_PROTOCOL_ERROR. Either way, a failure // indicates that no cipher suite was negotiated with the test server. - if (rv == net::ERR_IO_PENDING) + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); - EXPECT_TRUE(rv == net::ERR_SSL_VERSION_OR_CIPHER_MISMATCH || - rv == net::ERR_SSL_PROTOCOL_ERROR); + EXPECT_TRUE(rv == ERR_SSL_VERSION_OR_CIPHER_MISMATCH || + rv == ERR_SSL_PROTOCOL_ERROR); // The exact ordering differs between SSLClientSocketNSS (which issues an // extra read) and SSLClientSocketMac (which does not). Just make sure the // error appears somewhere in the log. log.GetEntries(&entries); - net::ExpectLogContainsSomewhere(entries, 0, - net::NetLog::TYPE_SSL_HANDSHAKE_ERROR, - net::NetLog::PHASE_NONE); + ExpectLogContainsSomewhere( + entries, 0, NetLog::TYPE_SSL_HANDSHAKE_ERROR, NetLog::PHASE_NONE); // We cannot test sock->IsConnected(), as the NSS implementation disconnects // the socket when it encounters an error, whereas other implementations @@ -1539,65 +1518,65 @@ TEST_F(SSLClientSocketTest, CipherSuiteDisables) { // Here we verify that such a simple ClientSocketHandle, not associated with any // client socket pool, can be destroyed safely. TEST_F(SSLClientSocketTest, ClientSocketHandleNotFromPool) { - net::SpawnedTestServer test_server(net::SpawnedTestServer::TYPE_HTTPS, - net::SpawnedTestServer::kLocalhost, - base::FilePath()); + SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS, + SpawnedTestServer::kLocalhost, + base::FilePath()); ASSERT_TRUE(test_server.Start()); - net::AddressList addr; + AddressList addr; ASSERT_TRUE(test_server.GetAddressList(&addr)); - net::TestCompletionCallback callback; - net::StreamSocket* transport = new net::TCPClientSocket( - addr, NULL, net::NetLog::Source()); + TestCompletionCallback callback; + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr, NULL, NetLog::Source())); int rv = transport->Connect(callback.callback()); - if (rv == net::ERR_IO_PENDING) + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); - EXPECT_EQ(net::OK, rv); + EXPECT_EQ(OK, rv); - net::ClientSocketHandle* socket_handle = new net::ClientSocketHandle(); - socket_handle->set_socket(transport); + scoped_ptr<ClientSocketHandle> socket_handle(new ClientSocketHandle()); + socket_handle->SetSocket(transport.Pass()); - scoped_ptr<net::SSLClientSocket> sock( - socket_factory_->CreateSSLClientSocket( - socket_handle, test_server.host_port_pair(), kDefaultSSLConfig, - context_)); + scoped_ptr<SSLClientSocket> sock( + socket_factory_->CreateSSLClientSocket(socket_handle.Pass(), + test_server.host_port_pair(), + kDefaultSSLConfig, + context_)); EXPECT_FALSE(sock->IsConnected()); rv = sock->Connect(callback.callback()); - if (rv == net::ERR_IO_PENDING) + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); - EXPECT_EQ(net::OK, rv); + EXPECT_EQ(OK, rv); } // Verifies that SSLClientSocket::ExportKeyingMaterial return a success // code and different keying label results in different keying material. TEST_F(SSLClientSocketTest, ExportKeyingMaterial) { - net::SpawnedTestServer test_server(net::SpawnedTestServer::TYPE_HTTPS, - net::SpawnedTestServer::kLocalhost, - base::FilePath()); + SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS, + SpawnedTestServer::kLocalhost, + base::FilePath()); ASSERT_TRUE(test_server.Start()); - net::AddressList addr; + AddressList addr; ASSERT_TRUE(test_server.GetAddressList(&addr)); - net::TestCompletionCallback callback; + TestCompletionCallback callback; - net::StreamSocket* transport = new net::TCPClientSocket( - addr, NULL, net::NetLog::Source()); + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr, NULL, NetLog::Source())); int rv = transport->Connect(callback.callback()); - if (rv == net::ERR_IO_PENDING) + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); - EXPECT_EQ(net::OK, rv); + EXPECT_EQ(OK, rv); - scoped_ptr<net::SSLClientSocket> sock( - CreateSSLClientSocket(transport, test_server.host_port_pair(), - kDefaultSSLConfig)); + scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( + transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); rv = sock->Connect(callback.callback()); - if (rv == net::ERR_IO_PENDING) + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); - EXPECT_EQ(net::OK, rv); + EXPECT_EQ(OK, rv); EXPECT_TRUE(sock->IsConnected()); const int kKeyingMaterialSize = 32; @@ -1605,23 +1584,23 @@ TEST_F(SSLClientSocketTest, ExportKeyingMaterial) { const char* kKeyingContext = ""; unsigned char client_out1[kKeyingMaterialSize]; memset(client_out1, 0, sizeof(client_out1)); - rv = sock->ExportKeyingMaterial(kKeyingLabel1, false, kKeyingContext, - client_out1, sizeof(client_out1)); - EXPECT_EQ(rv, net::OK); + rv = sock->ExportKeyingMaterial( + kKeyingLabel1, false, kKeyingContext, client_out1, sizeof(client_out1)); + EXPECT_EQ(rv, OK); const char* kKeyingLabel2 = "client-socket-test-2"; unsigned char client_out2[kKeyingMaterialSize]; memset(client_out2, 0, sizeof(client_out2)); - rv = sock->ExportKeyingMaterial(kKeyingLabel2, false, kKeyingContext, - client_out2, sizeof(client_out2)); - EXPECT_EQ(rv, net::OK); + rv = sock->ExportKeyingMaterial( + kKeyingLabel2, false, kKeyingContext, client_out2, sizeof(client_out2)); + EXPECT_EQ(rv, OK); EXPECT_NE(memcmp(client_out1, client_out2, kKeyingMaterialSize), 0); } // Verifies that SSLClientSocket::ClearSessionCache can be called without // explicit NSS initialization. TEST(SSLClientSocket, ClearSessionCache) { - net::SSLClientSocket::ClearSessionCache(); + SSLClientSocket::ClearSessionCache(); } // This tests that SSLInfo contains a properly re-constructed certificate @@ -1639,86 +1618,84 @@ TEST(SSLClientSocket, ClearSessionCache) { TEST_F(SSLClientSocketTest, VerifyReturnChainProperlyOrdered) { // By default, cause the CertVerifier to treat all certificates as // expired. - cert_verifier_->set_default_result(net::ERR_CERT_DATE_INVALID); + cert_verifier_->set_default_result(ERR_CERT_DATE_INVALID); // We will expect SSLInfo to ultimately contain this chain. - net::CertificateList certs = CreateCertificateListFromFile( - net::GetTestCertsDirectory(), "redundant-validated-chain.pem", - net::X509Certificate::FORMAT_AUTO); + CertificateList certs = + CreateCertificateListFromFile(GetTestCertsDirectory(), + "redundant-validated-chain.pem", + X509Certificate::FORMAT_AUTO); ASSERT_EQ(3U, certs.size()); - net::X509Certificate::OSCertHandles temp_intermediates; + X509Certificate::OSCertHandles temp_intermediates; temp_intermediates.push_back(certs[1]->os_cert_handle()); temp_intermediates.push_back(certs[2]->os_cert_handle()); - net::CertVerifyResult verify_result; - verify_result.verified_cert = - net::X509Certificate::CreateFromHandle(certs[0]->os_cert_handle(), - temp_intermediates); + CertVerifyResult verify_result; + verify_result.verified_cert = X509Certificate::CreateFromHandle( + certs[0]->os_cert_handle(), temp_intermediates); // Add a rule that maps the server cert (A) to the chain of A->B->C2 // rather than A->B->C. - cert_verifier_->AddResultForCert(certs[0].get(), verify_result, net::OK); + cert_verifier_->AddResultForCert(certs[0].get(), verify_result, OK); // Load and install the root for the validated chain. - scoped_refptr<net::X509Certificate> root_cert = - net::ImportCertFromFile(net::GetTestCertsDirectory(), - "redundant-validated-chain-root.pem"); - ASSERT_NE(static_cast<net::X509Certificate*>(NULL), root_cert); - net::ScopedTestRoot scoped_root(root_cert.get()); + scoped_refptr<X509Certificate> root_cert = ImportCertFromFile( + GetTestCertsDirectory(), "redundant-validated-chain-root.pem"); + ASSERT_NE(static_cast<X509Certificate*>(NULL), root_cert); + ScopedTestRoot scoped_root(root_cert.get()); // Set up a test server with CERT_CHAIN_WRONG_ROOT. - net::SpawnedTestServer::SSLOptions ssl_options( - net::SpawnedTestServer::SSLOptions::CERT_CHAIN_WRONG_ROOT); - net::SpawnedTestServer test_server( - net::SpawnedTestServer::TYPE_HTTPS, ssl_options, + SpawnedTestServer::SSLOptions ssl_options( + SpawnedTestServer::SSLOptions::CERT_CHAIN_WRONG_ROOT); + SpawnedTestServer test_server( + SpawnedTestServer::TYPE_HTTPS, + ssl_options, base::FilePath(FILE_PATH_LITERAL("net/data/ssl"))); ASSERT_TRUE(test_server.Start()); - net::AddressList addr; + AddressList addr; ASSERT_TRUE(test_server.GetAddressList(&addr)); - net::TestCompletionCallback callback; - net::CapturingNetLog log; - net::StreamSocket* transport = new net::TCPClientSocket( - addr, &log, net::NetLog::Source()); + TestCompletionCallback callback; + CapturingNetLog log; + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr, &log, NetLog::Source())); int rv = transport->Connect(callback.callback()); - if (rv == net::ERR_IO_PENDING) + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); - EXPECT_EQ(net::OK, rv); + EXPECT_EQ(OK, rv); - scoped_ptr<net::SSLClientSocket> sock( - CreateSSLClientSocket(transport, test_server.host_port_pair(), - kDefaultSSLConfig)); + scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( + transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); EXPECT_FALSE(sock->IsConnected()); rv = sock->Connect(callback.callback()); - net::CapturingNetLog::CapturedEntryList entries; + CapturingNetLog::CapturedEntryList entries; log.GetEntries(&entries); - EXPECT_TRUE(net::LogContainsBeginEvent( - entries, 5, net::NetLog::TYPE_SSL_CONNECT)); - if (rv == net::ERR_IO_PENDING) + EXPECT_TRUE(LogContainsBeginEvent(entries, 5, NetLog::TYPE_SSL_CONNECT)); + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); - EXPECT_EQ(net::OK, rv); + EXPECT_EQ(OK, rv); EXPECT_TRUE(sock->IsConnected()); log.GetEntries(&entries); EXPECT_TRUE(LogContainsSSLConnectEndEvent(entries, -1)); - net::SSLInfo ssl_info; + SSLInfo ssl_info; sock->GetSSLInfo(&ssl_info); // Verify that SSLInfo contains the corrected re-constructed chain A -> B // -> C2. - const net::X509Certificate::OSCertHandles& intermediates = + const X509Certificate::OSCertHandles& intermediates = ssl_info.cert->GetIntermediateCertificates(); ASSERT_EQ(2U, intermediates.size()); - EXPECT_TRUE(net::X509Certificate::IsSameOSCert( - ssl_info.cert->os_cert_handle(), certs[0]->os_cert_handle())); - EXPECT_TRUE(net::X509Certificate::IsSameOSCert( - intermediates[0], certs[1]->os_cert_handle())); - EXPECT_TRUE(net::X509Certificate::IsSameOSCert( - intermediates[1], certs[2]->os_cert_handle())); + EXPECT_TRUE(X509Certificate::IsSameOSCert(ssl_info.cert->os_cert_handle(), + certs[0]->os_cert_handle())); + EXPECT_TRUE(X509Certificate::IsSameOSCert(intermediates[0], + certs[1]->os_cert_handle())); + EXPECT_TRUE(X509Certificate::IsSameOSCert(intermediates[1], + certs[2]->os_cert_handle())); sock->Disconnect(); EXPECT_FALSE(sock->IsConnected()); @@ -1729,37 +1706,34 @@ class SSLClientSocketCertRequestInfoTest : public SSLClientSocketTest { protected: // Creates a test server with the given SSLOptions, connects to it and returns // the SSLCertRequestInfo reported by the socket. - scoped_refptr<net::SSLCertRequestInfo> GetCertRequest( - net::SpawnedTestServer::SSLOptions ssl_options) { - net::SpawnedTestServer test_server(net::SpawnedTestServer::TYPE_HTTPS, - ssl_options, - base::FilePath()); + scoped_refptr<SSLCertRequestInfo> GetCertRequest( + SpawnedTestServer::SSLOptions ssl_options) { + SpawnedTestServer test_server( + SpawnedTestServer::TYPE_HTTPS, ssl_options, base::FilePath()); if (!test_server.Start()) return NULL; - net::AddressList addr; + AddressList addr; if (!test_server.GetAddressList(&addr)) return NULL; - net::TestCompletionCallback callback; - net::CapturingNetLog log; - net::StreamSocket* transport = new net::TCPClientSocket( - addr, &log, net::NetLog::Source()); + TestCompletionCallback callback; + CapturingNetLog log; + scoped_ptr<StreamSocket> transport( + new TCPClientSocket(addr, &log, NetLog::Source())); int rv = transport->Connect(callback.callback()); - if (rv == net::ERR_IO_PENDING) + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); - EXPECT_EQ(net::OK, rv); + EXPECT_EQ(OK, rv); - scoped_ptr<net::SSLClientSocket> sock( - CreateSSLClientSocket(transport, test_server.host_port_pair(), - kDefaultSSLConfig)); + scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( + transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); EXPECT_FALSE(sock->IsConnected()); rv = sock->Connect(callback.callback()); - if (rv == net::ERR_IO_PENDING) + if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); - scoped_refptr<net::SSLCertRequestInfo> request_info = - new net::SSLCertRequestInfo(); + scoped_refptr<SSLCertRequestInfo> request_info = new SSLCertRequestInfo(); sock->GetSSLCertRequestInfo(request_info.get()); sock->Disconnect(); EXPECT_FALSE(sock->IsConnected()); @@ -1769,10 +1743,9 @@ class SSLClientSocketCertRequestInfoTest : public SSLClientSocketTest { }; TEST_F(SSLClientSocketCertRequestInfoTest, NoAuthorities) { - net::SpawnedTestServer::SSLOptions ssl_options; + SpawnedTestServer::SSLOptions ssl_options; ssl_options.request_client_certificate = true; - scoped_refptr<net::SSLCertRequestInfo> request_info = - GetCertRequest(ssl_options); + scoped_refptr<SSLCertRequestInfo> request_info = GetCertRequest(ssl_options); ASSERT_TRUE(request_info.get()); EXPECT_EQ(0u, request_info->cert_authorities.size()); } @@ -1781,39 +1754,36 @@ TEST_F(SSLClientSocketCertRequestInfoTest, TwoAuthorities) { const base::FilePath::CharType kThawteFile[] = FILE_PATH_LITERAL("thawte.single.pem"); const unsigned char kThawteDN[] = { - 0x30, 0x4c, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03, 0x55, 0x04, 0x06, 0x13, - 0x02, 0x5a, 0x41, 0x31, 0x25, 0x30, 0x23, 0x06, 0x03, 0x55, 0x04, 0x0a, - 0x13, 0x1c, 0x54, 0x68, 0x61, 0x77, 0x74, 0x65, 0x20, 0x43, 0x6f, 0x6e, - 0x73, 0x75, 0x6c, 0x74, 0x69, 0x6e, 0x67, 0x20, 0x28, 0x50, 0x74, 0x79, - 0x29, 0x20, 0x4c, 0x74, 0x64, 0x2e, 0x31, 0x16, 0x30, 0x14, 0x06, 0x03, - 0x55, 0x04, 0x03, 0x13, 0x0d, 0x54, 0x68, 0x61, 0x77, 0x74, 0x65, 0x20, - 0x53, 0x47, 0x43, 0x20, 0x43, 0x41 - }; + 0x30, 0x4c, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03, 0x55, 0x04, 0x06, 0x13, + 0x02, 0x5a, 0x41, 0x31, 0x25, 0x30, 0x23, 0x06, 0x03, 0x55, 0x04, 0x0a, + 0x13, 0x1c, 0x54, 0x68, 0x61, 0x77, 0x74, 0x65, 0x20, 0x43, 0x6f, 0x6e, + 0x73, 0x75, 0x6c, 0x74, 0x69, 0x6e, 0x67, 0x20, 0x28, 0x50, 0x74, 0x79, + 0x29, 0x20, 0x4c, 0x74, 0x64, 0x2e, 0x31, 0x16, 0x30, 0x14, 0x06, 0x03, + 0x55, 0x04, 0x03, 0x13, 0x0d, 0x54, 0x68, 0x61, 0x77, 0x74, 0x65, 0x20, + 0x53, 0x47, 0x43, 0x20, 0x43, 0x41}; const size_t kThawteLen = sizeof(kThawteDN); const base::FilePath::CharType kDiginotarFile[] = FILE_PATH_LITERAL("diginotar_root_ca.pem"); const unsigned char kDiginotarDN[] = { - 0x30, 0x5f, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03, 0x55, 0x04, 0x06, 0x13, - 0x02, 0x4e, 0x4c, 0x31, 0x12, 0x30, 0x10, 0x06, 0x03, 0x55, 0x04, 0x0a, - 0x13, 0x09, 0x44, 0x69, 0x67, 0x69, 0x4e, 0x6f, 0x74, 0x61, 0x72, 0x31, - 0x1a, 0x30, 0x18, 0x06, 0x03, 0x55, 0x04, 0x03, 0x13, 0x11, 0x44, 0x69, - 0x67, 0x69, 0x4e, 0x6f, 0x74, 0x61, 0x72, 0x20, 0x52, 0x6f, 0x6f, 0x74, - 0x20, 0x43, 0x41, 0x31, 0x20, 0x30, 0x1e, 0x06, 0x09, 0x2a, 0x86, 0x48, - 0x86, 0xf7, 0x0d, 0x01, 0x09, 0x01, 0x16, 0x11, 0x69, 0x6e, 0x66, 0x6f, - 0x40, 0x64, 0x69, 0x67, 0x69, 0x6e, 0x6f, 0x74, 0x61, 0x72, 0x2e, 0x6e, - 0x6c - }; + 0x30, 0x5f, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03, 0x55, 0x04, 0x06, 0x13, + 0x02, 0x4e, 0x4c, 0x31, 0x12, 0x30, 0x10, 0x06, 0x03, 0x55, 0x04, 0x0a, + 0x13, 0x09, 0x44, 0x69, 0x67, 0x69, 0x4e, 0x6f, 0x74, 0x61, 0x72, 0x31, + 0x1a, 0x30, 0x18, 0x06, 0x03, 0x55, 0x04, 0x03, 0x13, 0x11, 0x44, 0x69, + 0x67, 0x69, 0x4e, 0x6f, 0x74, 0x61, 0x72, 0x20, 0x52, 0x6f, 0x6f, 0x74, + 0x20, 0x43, 0x41, 0x31, 0x20, 0x30, 0x1e, 0x06, 0x09, 0x2a, 0x86, 0x48, + 0x86, 0xf7, 0x0d, 0x01, 0x09, 0x01, 0x16, 0x11, 0x69, 0x6e, 0x66, 0x6f, + 0x40, 0x64, 0x69, 0x67, 0x69, 0x6e, 0x6f, 0x74, 0x61, 0x72, 0x2e, 0x6e, + 0x6c}; const size_t kDiginotarLen = sizeof(kDiginotarDN); - net::SpawnedTestServer::SSLOptions ssl_options; + SpawnedTestServer::SSLOptions ssl_options; ssl_options.request_client_certificate = true; ssl_options.client_authorities.push_back( - net::GetTestClientCertsDirectory().Append(kThawteFile)); + GetTestClientCertsDirectory().Append(kThawteFile)); ssl_options.client_authorities.push_back( - net::GetTestClientCertsDirectory().Append(kDiginotarFile)); - scoped_refptr<net::SSLCertRequestInfo> request_info = - GetCertRequest(ssl_options); + GetTestClientCertsDirectory().Append(kDiginotarFile)); + scoped_refptr<SSLCertRequestInfo> request_info = GetCertRequest(ssl_options); ASSERT_TRUE(request_info.get()); ASSERT_EQ(2u, request_info->cert_authorities.size()); EXPECT_EQ(std::string(reinterpret_cast<const char*>(kThawteDN), kThawteLen), @@ -1822,3 +1792,7 @@ TEST_F(SSLClientSocketCertRequestInfoTest, TwoAuthorities) { std::string(reinterpret_cast<const char*>(kDiginotarDN), kDiginotarLen), request_info->cert_authorities[1]); } + +} // namespace + +} // namespace net diff --git a/chromium/net/socket/ssl_server_socket.h b/chromium/net/socket/ssl_server_socket.h index 52d53cb19a2..8b607bf80cf 100644 --- a/chromium/net/socket/ssl_server_socket.h +++ b/chromium/net/socket/ssl_server_socket.h @@ -6,6 +6,7 @@ #define NET_SOCKET_SSL_SERVER_SOCKET_H_ #include "base/basictypes.h" +#include "base/memory/scoped_ptr.h" #include "net/base/completion_callback.h" #include "net/base/net_export.h" #include "net/socket/ssl_socket.h" @@ -52,8 +53,8 @@ NET_EXPORT void EnableSSLServerSockets(); // // The caller starts the SSL server handshake by calling Handshake on the // returned socket. -NET_EXPORT SSLServerSocket* CreateSSLServerSocket( - StreamSocket* socket, +NET_EXPORT scoped_ptr<SSLServerSocket> CreateSSLServerSocket( + scoped_ptr<StreamSocket> socket, X509Certificate* certificate, crypto::RSAPrivateKey* key, const SSLConfig& ssl_config); diff --git a/chromium/net/socket/ssl_server_socket_nss.cc b/chromium/net/socket/ssl_server_socket_nss.cc index c2681d3ee14..7e5d70118ac 100644 --- a/chromium/net/socket/ssl_server_socket_nss.cc +++ b/chromium/net/socket/ssl_server_socket_nss.cc @@ -78,19 +78,20 @@ void EnableSSLServerSockets() { g_nss_ssl_server_init_singleton.Get(); } -SSLServerSocket* CreateSSLServerSocket( - StreamSocket* socket, +scoped_ptr<SSLServerSocket> CreateSSLServerSocket( + scoped_ptr<StreamSocket> socket, X509Certificate* cert, crypto::RSAPrivateKey* key, const SSLConfig& ssl_config) { DCHECK(g_nss_server_sockets_init) << "EnableSSLServerSockets() has not been" << "called yet!"; - return new SSLServerSocketNSS(socket, cert, key, ssl_config); + return scoped_ptr<SSLServerSocket>( + new SSLServerSocketNSS(socket.Pass(), cert, key, ssl_config)); } SSLServerSocketNSS::SSLServerSocketNSS( - StreamSocket* transport_socket, + scoped_ptr<StreamSocket> transport_socket, scoped_refptr<X509Certificate> cert, crypto::RSAPrivateKey* key, const SSLConfig& ssl_config) @@ -100,7 +101,7 @@ SSLServerSocketNSS::SSLServerSocketNSS( user_write_buf_len_(0), nss_fd_(NULL), nss_bufs_(NULL), - transport_socket_(transport_socket), + transport_socket_(transport_socket.Pass()), ssl_config_(ssl_config), cert_(cert), next_handshake_state_(STATE_NONE), diff --git a/chromium/net/socket/ssl_server_socket_nss.h b/chromium/net/socket/ssl_server_socket_nss.h index 17a1fc38750..8bbb0e338ac 100644 --- a/chromium/net/socket/ssl_server_socket_nss.h +++ b/chromium/net/socket/ssl_server_socket_nss.h @@ -24,7 +24,7 @@ class SSLServerSocketNSS : public SSLServerSocket { public: // See comments on CreateSSLServerSocket for details of how these // parameters are used. - SSLServerSocketNSS(StreamSocket* socket, + SSLServerSocketNSS(scoped_ptr<StreamSocket> socket, scoped_refptr<X509Certificate> certificate, crypto::RSAPrivateKey* key, const SSLConfig& ssl_config); diff --git a/chromium/net/socket/ssl_server_socket_openssl.cc b/chromium/net/socket/ssl_server_socket_openssl.cc index e0cf8bc0b21..c327f2caf10 100644 --- a/chromium/net/socket/ssl_server_socket_openssl.cc +++ b/chromium/net/socket/ssl_server_socket_openssl.cc @@ -16,13 +16,13 @@ void EnableSSLServerSockets() { NOTIMPLEMENTED(); } -SSLServerSocket* CreateSSLServerSocket(StreamSocket* socket, - X509Certificate* certificate, - crypto::RSAPrivateKey* key, - const SSLConfig& ssl_config) { +scoped_ptr<SSLServerSocket> CreateSSLServerSocket( + scoped_ptr<StreamSocket> socket, + X509Certificate* certificate, + crypto::RSAPrivateKey* key, + const SSLConfig& ssl_config) { NOTIMPLEMENTED(); - delete socket; - return NULL; + return scoped_ptr<SSLServerSocket>(); } } // namespace net diff --git a/chromium/net/socket/ssl_server_socket_unittest.cc b/chromium/net/socket/ssl_server_socket_unittest.cc index f931e2c957e..e1f7f496131 100644 --- a/chromium/net/socket/ssl_server_socket_unittest.cc +++ b/chromium/net/socket/ssl_server_socket_unittest.cc @@ -304,21 +304,24 @@ class SSLServerSocketTest : public PlatformTest { protected: void Initialize() { - FakeSocket* fake_client_socket = new FakeSocket(&channel_1_, &channel_2_); - FakeSocket* fake_server_socket = new FakeSocket(&channel_2_, &channel_1_); + scoped_ptr<ClientSocketHandle> client_connection(new ClientSocketHandle); + client_connection->SetSocket( + scoped_ptr<StreamSocket>(new FakeSocket(&channel_1_, &channel_2_))); + scoped_ptr<StreamSocket> server_socket( + new FakeSocket(&channel_2_, &channel_1_)); base::FilePath certs_dir(GetTestCertsDirectory()); base::FilePath cert_path = certs_dir.AppendASCII("unittest.selfsigned.der"); std::string cert_der; - ASSERT_TRUE(file_util::ReadFileToString(cert_path, &cert_der)); + ASSERT_TRUE(base::ReadFileToString(cert_path, &cert_der)); scoped_refptr<net::X509Certificate> cert = X509Certificate::CreateFromBytes(cert_der.data(), cert_der.size()); base::FilePath key_path = certs_dir.AppendASCII("unittest.key.bin"); std::string key_string; - ASSERT_TRUE(file_util::ReadFileToString(key_path, &key_string)); + ASSERT_TRUE(base::ReadFileToString(key_path, &key_string)); std::vector<uint8> key_vector( reinterpret_cast<const uint8*>(key_string.data()), reinterpret_cast<const uint8*>(key_string.data() + @@ -344,11 +347,12 @@ class SSLServerSocketTest : public PlatformTest { net::SSLClientSocketContext context; context.cert_verifier = cert_verifier_.get(); context.transport_security_state = transport_security_state_.get(); - client_socket_.reset( + client_socket_ = socket_factory_->CreateSSLClientSocket( - fake_client_socket, host_and_pair, ssl_config, context)); - server_socket_.reset(net::CreateSSLServerSocket( - fake_server_socket, cert.get(), private_key.get(), net::SSLConfig())); + client_connection.Pass(), host_and_pair, ssl_config, context); + server_socket_ = net::CreateSSLServerSocket( + server_socket.Pass(), + cert.get(), private_key.get(), net::SSLConfig()); } FakeDataChannel channel_1_; diff --git a/chromium/net/socket/stream_listen_socket.cc b/chromium/net/socket/stream_listen_socket.cc index c85c671800d..1109e7527c3 100644 --- a/chromium/net/socket/stream_listen_socket.cc +++ b/chromium/net/socket/stream_listen_socket.cc @@ -27,6 +27,7 @@ #include "net/base/ip_endpoint.h" #include "net/base/net_errors.h" #include "net/base/net_util.h" +#include "net/socket/socket_descriptor.h" using std::string; @@ -43,10 +44,8 @@ const int kReadBufSize = 4096; } // namespace #if defined(OS_WIN) -const SocketDescriptor StreamListenSocket::kInvalidSocket = INVALID_SOCKET; const int StreamListenSocket::kSocketError = SOCKET_ERROR; #elif defined(OS_POSIX) -const SocketDescriptor StreamListenSocket::kInvalidSocket = -1; const int StreamListenSocket::kSocketError = -1; #endif diff --git a/chromium/net/socket/stream_listen_socket.h b/chromium/net/socket/stream_listen_socket.h index 6f03eefaca2..9825a4ef126 100644 --- a/chromium/net/socket/stream_listen_socket.h +++ b/chromium/net/socket/stream_listen_socket.h @@ -30,21 +30,16 @@ #include "base/basictypes.h" #include "base/compiler_specific.h" +#include "base/memory/scoped_ptr.h" #include "net/base/net_export.h" -#include "net/socket/stream_listen_socket.h" - -#if defined(OS_POSIX) -typedef int SocketDescriptor; -#else -typedef SOCKET SocketDescriptor; -#endif +#include "net/socket/socket_descriptor.h" namespace net { class IPEndPoint; class NET_EXPORT StreamListenSocket - : public base::RefCountedThreadSafe<StreamListenSocket>, + : #if defined(OS_WIN) public base::win::ObjectWatcher::Delegate { #elif defined(OS_POSIX) @@ -52,16 +47,17 @@ class NET_EXPORT StreamListenSocket #endif public: + virtual ~StreamListenSocket(); + // TODO(erikkay): this delegate should really be split into two parts // to split up the listener from the connected socket. Perhaps this class // should be split up similarly. class Delegate { public: // |server| is the original listening Socket, connection is the new - // Socket that was created. Ownership of |connection| is transferred - // to the delegate with this call. + // Socket that was created. virtual void DidAccept(StreamListenSocket* server, - StreamListenSocket* connection) = 0; + scoped_ptr<StreamListenSocket> connection) = 0; virtual void DidRead(StreamListenSocket* connection, const char* data, int len) = 0; @@ -78,7 +74,6 @@ class NET_EXPORT StreamListenSocket // Copies the local address to |address|. Returns a network error code. int GetLocalAddress(IPEndPoint* address); - static const SocketDescriptor kInvalidSocket; static const int kSocketError; protected: @@ -89,7 +84,6 @@ class NET_EXPORT StreamListenSocket }; StreamListenSocket(SocketDescriptor s, Delegate* del); - virtual ~StreamListenSocket(); SocketDescriptor AcceptSocket(); virtual void Accept() = 0; @@ -107,7 +101,6 @@ class NET_EXPORT StreamListenSocket Delegate* const socket_delegate_; private: - friend class base::RefCountedThreadSafe<StreamListenSocket>; friend class TransportClientSocketTest; void SendInternal(const char* bytes, int len); @@ -146,7 +139,7 @@ class NET_EXPORT StreamListenSocketFactory { virtual ~StreamListenSocketFactory() {} // Returns a new instance of StreamListenSocket or NULL if an error occurred. - virtual scoped_refptr<StreamListenSocket> CreateAndListen( + virtual scoped_ptr<StreamListenSocket> CreateAndListen( StreamListenSocket::Delegate* delegate) const = 0; }; diff --git a/chromium/net/socket/tcp_client_socket.cc b/chromium/net/socket/tcp_client_socket.cc index dbd21056f39..22aea47778b 100644 --- a/chromium/net/socket/tcp_client_socket.cc +++ b/chromium/net/socket/tcp_client_socket.cc @@ -4,56 +4,317 @@ #include "net/socket/tcp_client_socket.h" -#include "base/file_util.h" -#include "base/files/file_path.h" +#include "base/callback_helpers.h" +#include "base/logging.h" +#include "net/base/io_buffer.h" +#include "net/base/ip_endpoint.h" +#include "net/base/net_errors.h" +#include "net/base/net_util.h" namespace net { -namespace { +TCPClientSocket::TCPClientSocket(const AddressList& addresses, + net::NetLog* net_log, + const net::NetLog::Source& source) + : socket_(new TCPSocket(net_log, source)), + addresses_(addresses), + current_address_index_(-1), + next_connect_state_(CONNECT_STATE_NONE), + previously_disconnected_(false) { +} + +TCPClientSocket::TCPClientSocket(scoped_ptr<TCPSocket> connected_socket, + const IPEndPoint& peer_address) + : socket_(connected_socket.Pass()), + addresses_(AddressList(peer_address)), + current_address_index_(0), + next_connect_state_(CONNECT_STATE_NONE), + previously_disconnected_(false) { + DCHECK(socket_); + + socket_->SetDefaultOptionsForClient(); + use_history_.set_was_ever_connected(); +} + +TCPClientSocket::~TCPClientSocket() { +} + +int TCPClientSocket::Bind(const IPEndPoint& address) { + if (current_address_index_ >= 0 || bind_address_) { + // Cannot bind the socket if we are already connected or connecting. + NOTREACHED(); + return ERR_UNEXPECTED; + } + + int result = OK; + if (!socket_->IsValid()) { + result = OpenSocket(address.GetFamily()); + if (result != OK) + return result; + } + + result = socket_->Bind(address); + if (result != OK) + return result; + + bind_address_.reset(new IPEndPoint(address)); + return OK; +} + +int TCPClientSocket::Connect(const CompletionCallback& callback) { + DCHECK(!callback.is_null()); + + // If connecting or already connected, then just return OK. + if (socket_->IsValid() && current_address_index_ >= 0) + return OK; + + socket_->StartLoggingMultipleConnectAttempts(addresses_); + + // We will try to connect to each address in addresses_. Start with the + // first one in the list. + next_connect_state_ = CONNECT_STATE_CONNECT; + current_address_index_ = 0; + + int rv = DoConnectLoop(OK); + if (rv == ERR_IO_PENDING) { + connect_callback_ = callback; + } else { + socket_->EndLoggingMultipleConnectAttempts(rv); + } + + return rv; +} -#if defined(OS_LINUX) +int TCPClientSocket::DoConnectLoop(int result) { + DCHECK_NE(next_connect_state_, CONNECT_STATE_NONE); -// Checks to see if the system supports TCP FastOpen. Notably, it requires -// kernel support. Additionally, this checks system configuration to ensure that -// it's enabled. -bool SystemSupportsTCPFastOpen() { - static const base::FilePath::CharType kTCPFastOpenProcFilePath[] = - "/proc/sys/net/ipv4/tcp_fastopen"; - std::string system_enabled_tcp_fastopen; - if (!file_util::ReadFileToString( - base::FilePath(kTCPFastOpenProcFilePath), - &system_enabled_tcp_fastopen)) { - return false; + int rv = result; + do { + ConnectState state = next_connect_state_; + next_connect_state_ = CONNECT_STATE_NONE; + switch (state) { + case CONNECT_STATE_CONNECT: + DCHECK_EQ(OK, rv); + rv = DoConnect(); + break; + case CONNECT_STATE_CONNECT_COMPLETE: + rv = DoConnectComplete(rv); + break; + default: + NOTREACHED() << "bad state " << state; + rv = ERR_UNEXPECTED; + break; + } + } while (rv != ERR_IO_PENDING && next_connect_state_ != CONNECT_STATE_NONE); + + return rv; +} + +int TCPClientSocket::DoConnect() { + DCHECK_GE(current_address_index_, 0); + DCHECK_LT(current_address_index_, static_cast<int>(addresses_.size())); + + const IPEndPoint& endpoint = addresses_[current_address_index_]; + + if (previously_disconnected_) { + use_history_.Reset(); + previously_disconnected_ = false; } - // As per http://lxr.linux.no/linux+v3.7.7/include/net/tcp.h#L225 - // TFO_CLIENT_ENABLE is the LSB - if (system_enabled_tcp_fastopen.empty() || - (system_enabled_tcp_fastopen[0] & 0x1) == 0) { - return false; + next_connect_state_ = CONNECT_STATE_CONNECT_COMPLETE; + + if (socket_->IsValid()) { + DCHECK(bind_address_); + } else { + int result = OpenSocket(endpoint.GetFamily()); + if (result != OK) + return result; + + if (bind_address_) { + result = socket_->Bind(*bind_address_); + if (result != OK) { + socket_->Close(); + return result; + } + } + } + + // |socket_| is owned by this class and the callback won't be run once + // |socket_| is gone. Therefore, it is safe to use base::Unretained() here. + return socket_->Connect(endpoint, + base::Bind(&TCPClientSocket::DidCompleteConnect, + base::Unretained(this))); +} + +int TCPClientSocket::DoConnectComplete(int result) { + if (result == OK) { + use_history_.set_was_ever_connected(); + return OK; // Done! + } + + // Close whatever partially connected socket we currently have. + DoDisconnect(); + + // Try to fall back to the next address in the list. + if (current_address_index_ + 1 < static_cast<int>(addresses_.size())) { + next_connect_state_ = CONNECT_STATE_CONNECT; + ++current_address_index_; + return OK; + } + + // Otherwise there is nothing to fall back to, so give up. + return result; +} + +void TCPClientSocket::Disconnect() { + DoDisconnect(); + current_address_index_ = -1; + bind_address_.reset(); +} + +void TCPClientSocket::DoDisconnect() { + // If connecting or already connected, record that the socket has been + // disconnected. + previously_disconnected_ = socket_->IsValid() && current_address_index_ >= 0; + socket_->Close(); +} + +bool TCPClientSocket::IsConnected() const { + return socket_->IsConnected(); +} + +bool TCPClientSocket::IsConnectedAndIdle() const { + return socket_->IsConnectedAndIdle(); +} + +int TCPClientSocket::GetPeerAddress(IPEndPoint* address) const { + return socket_->GetPeerAddress(address); +} + +int TCPClientSocket::GetLocalAddress(IPEndPoint* address) const { + DCHECK(address); + + if (!socket_->IsValid()) { + if (bind_address_) { + *address = *bind_address_; + return OK; + } + return ERR_SOCKET_NOT_CONNECTED; } - return true; + return socket_->GetLocalAddress(address); +} + +const BoundNetLog& TCPClientSocket::NetLog() const { + return socket_->net_log(); +} + +void TCPClientSocket::SetSubresourceSpeculation() { + use_history_.set_subresource_speculation(); +} + +void TCPClientSocket::SetOmniboxSpeculation() { + use_history_.set_omnibox_speculation(); +} + +bool TCPClientSocket::WasEverUsed() const { + return use_history_.was_used_to_convey_data(); +} + +bool TCPClientSocket::UsingTCPFastOpen() const { + return socket_->UsingTCPFastOpen(); +} + +bool TCPClientSocket::WasNpnNegotiated() const { + return false; } -#else +NextProto TCPClientSocket::GetNegotiatedProtocol() const { + return kProtoUnknown; +} -bool SystemSupportsTCPFastOpen() { +bool TCPClientSocket::GetSSLInfo(SSLInfo* ssl_info) { return false; } -#endif +int TCPClientSocket::Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) { + DCHECK(!callback.is_null()); + + // |socket_| is owned by this class and the callback won't be run once + // |socket_| is gone. Therefore, it is safe to use base::Unretained() here. + CompletionCallback read_callback = base::Bind( + &TCPClientSocket::DidCompleteReadWrite, base::Unretained(this), callback); + int result = socket_->Read(buf, buf_len, read_callback); + if (result > 0) + use_history_.set_was_used_to_convey_data(); + + return result; +} + +int TCPClientSocket::Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) { + DCHECK(!callback.is_null()); + + // |socket_| is owned by this class and the callback won't be run once + // |socket_| is gone. Therefore, it is safe to use base::Unretained() here. + CompletionCallback write_callback = base::Bind( + &TCPClientSocket::DidCompleteReadWrite, base::Unretained(this), callback); + int result = socket_->Write(buf, buf_len, write_callback); + if (result > 0) + use_history_.set_was_used_to_convey_data(); + return result; } -static bool g_tcp_fastopen_enabled = false; +bool TCPClientSocket::SetReceiveBufferSize(int32 size) { + return socket_->SetReceiveBufferSize(size); +} + +bool TCPClientSocket::SetSendBufferSize(int32 size) { + return socket_->SetSendBufferSize(size); +} + +bool TCPClientSocket::SetKeepAlive(bool enable, int delay) { + return socket_->SetKeepAlive(enable, delay); +} -void SetTCPFastOpenEnabled(bool value) { - g_tcp_fastopen_enabled = value && SystemSupportsTCPFastOpen(); +bool TCPClientSocket::SetNoDelay(bool no_delay) { + return socket_->SetNoDelay(no_delay); } -bool IsTCPFastOpenEnabled() { - return g_tcp_fastopen_enabled; +void TCPClientSocket::DidCompleteConnect(int result) { + DCHECK_EQ(next_connect_state_, CONNECT_STATE_CONNECT_COMPLETE); + DCHECK_NE(result, ERR_IO_PENDING); + DCHECK(!connect_callback_.is_null()); + + result = DoConnectLoop(result); + if (result != ERR_IO_PENDING) { + socket_->EndLoggingMultipleConnectAttempts(result); + base::ResetAndReturn(&connect_callback_).Run(result); + } +} + +void TCPClientSocket::DidCompleteReadWrite(const CompletionCallback& callback, + int result) { + if (result > 0) + use_history_.set_was_used_to_convey_data(); + + callback.Run(result); +} + +int TCPClientSocket::OpenSocket(AddressFamily family) { + DCHECK(!socket_->IsValid()); + + int result = socket_->Open(family); + if (result != OK) + return result; + + socket_->SetDefaultOptionsForClient(); + + return OK; } } // namespace net diff --git a/chromium/net/socket/tcp_client_socket.h b/chromium/net/socket/tcp_client_socket.h index 8a2c0cd73f0..fabcbc1b39d 100644 --- a/chromium/net/socket/tcp_client_socket.h +++ b/chromium/net/socket/tcp_client_socket.h @@ -5,30 +5,116 @@ #ifndef NET_SOCKET_TCP_CLIENT_SOCKET_H_ #define NET_SOCKET_TCP_CLIENT_SOCKET_H_ -#include "build/build_config.h" +#include "base/basictypes.h" +#include "base/compiler_specific.h" +#include "base/memory/scoped_ptr.h" +#include "net/base/address_list.h" +#include "net/base/completion_callback.h" #include "net/base/net_export.h" - -#if defined(OS_WIN) -#include "net/socket/tcp_client_socket_win.h" -#elif defined(OS_POSIX) -#include "net/socket/tcp_client_socket_libevent.h" -#endif +#include "net/base/net_log.h" +#include "net/socket/stream_socket.h" +#include "net/socket/tcp_socket.h" namespace net { // A client socket that uses TCP as the transport layer. -#if defined(OS_WIN) -typedef TCPClientSocketWin TCPClientSocket; -#elif defined(OS_POSIX) -typedef TCPClientSocketLibevent TCPClientSocket; -#endif - -// Enable/disable experimental TCP FastOpen option. -// Not thread safe. Must be called during initialization/startup only. -NET_EXPORT void SetTCPFastOpenEnabled(bool value); - -// Check if the TCP FastOpen option is enabled. -bool IsTCPFastOpenEnabled(); +class NET_EXPORT TCPClientSocket : public StreamSocket { + public: + // The IP address(es) and port number to connect to. The TCP socket will try + // each IP address in the list until it succeeds in establishing a + // connection. + TCPClientSocket(const AddressList& addresses, + net::NetLog* net_log, + const net::NetLog::Source& source); + + // Adopts the given, connected socket and then acts as if Connect() had been + // called. This function is used by TCPServerSocket and for testing. + TCPClientSocket(scoped_ptr<TCPSocket> connected_socket, + const IPEndPoint& peer_address); + + virtual ~TCPClientSocket(); + + // Binds the socket to a local IP address and port. + int Bind(const IPEndPoint& address); + + // StreamSocket implementation. + virtual int Connect(const CompletionCallback& callback) OVERRIDE; + virtual void Disconnect() OVERRIDE; + virtual bool IsConnected() const OVERRIDE; + virtual bool IsConnectedAndIdle() const OVERRIDE; + virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE; + virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE; + virtual const BoundNetLog& NetLog() const OVERRIDE; + virtual void SetSubresourceSpeculation() OVERRIDE; + virtual void SetOmniboxSpeculation() OVERRIDE; + virtual bool WasEverUsed() const OVERRIDE; + virtual bool UsingTCPFastOpen() const OVERRIDE; + virtual bool WasNpnNegotiated() const OVERRIDE; + virtual NextProto GetNegotiatedProtocol() const OVERRIDE; + virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE; + + // Socket implementation. + // Multiple outstanding requests are not supported. + // Full duplex mode (reading and writing at the same time) is supported. + virtual int Read(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) OVERRIDE; + virtual int Write(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) OVERRIDE; + virtual bool SetReceiveBufferSize(int32 size) OVERRIDE; + virtual bool SetSendBufferSize(int32 size) OVERRIDE; + + virtual bool SetKeepAlive(bool enable, int delay); + virtual bool SetNoDelay(bool no_delay); + + private: + // State machine for connecting the socket. + enum ConnectState { + CONNECT_STATE_CONNECT, + CONNECT_STATE_CONNECT_COMPLETE, + CONNECT_STATE_NONE, + }; + + // State machine used by Connect(). + int DoConnectLoop(int result); + int DoConnect(); + int DoConnectComplete(int result); + + // Helper used by Disconnect(), which disconnects minus resetting + // current_address_index_ and bind_address_. + void DoDisconnect(); + + void DidCompleteConnect(int result); + void DidCompleteReadWrite(const CompletionCallback& callback, int result); + + int OpenSocket(AddressFamily family); + + scoped_ptr<TCPSocket> socket_; + + // Local IP address and port we are bound to. Set to NULL if Bind() + // wasn't called (in that case OS chooses address/port). + scoped_ptr<IPEndPoint> bind_address_; + + // The list of addresses we should try in order to establish a connection. + AddressList addresses_; + + // Where we are in above list. Set to -1 if uninitialized. + int current_address_index_; + + // External callback; called when connect is complete. + CompletionCallback connect_callback_; + + // The next state for the Connect() state machine. + ConnectState next_connect_state_; + + // This socket was previously disconnected and has not been re-connected. + bool previously_disconnected_; + + // Record of connectivity and transmissions, for use in speculative connection + // histograms. + UseHistory use_history_; + + DISALLOW_COPY_AND_ASSIGN(TCPClientSocket); +}; } // namespace net diff --git a/chromium/net/socket/tcp_client_socket_libevent.h b/chromium/net/socket/tcp_client_socket_libevent.h deleted file mode 100644 index e5a0d8deab4..00000000000 --- a/chromium/net/socket/tcp_client_socket_libevent.h +++ /dev/null @@ -1,256 +0,0 @@ -// Copyright (c) 2012 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#ifndef NET_SOCKET_TCP_CLIENT_SOCKET_LIBEVENT_H_ -#define NET_SOCKET_TCP_CLIENT_SOCKET_LIBEVENT_H_ - -#include "base/memory/ref_counted.h" -#include "base/memory/scoped_ptr.h" -#include "base/message_loop/message_loop.h" -#include "base/threading/non_thread_safe.h" -#include "net/base/address_list.h" -#include "net/base/completion_callback.h" -#include "net/base/net_log.h" -#include "net/socket/stream_socket.h" - -namespace net { - -class BoundNetLog; - -// A client socket that uses TCP as the transport layer. -class NET_EXPORT_PRIVATE TCPClientSocketLibevent : public StreamSocket, - public base::NonThreadSafe { - public: - // The IP address(es) and port number to connect to. The TCP socket will try - // each IP address in the list until it succeeds in establishing a - // connection. - TCPClientSocketLibevent(const AddressList& addresses, - net::NetLog* net_log, - const net::NetLog::Source& source); - - virtual ~TCPClientSocketLibevent(); - - // AdoptSocket causes the given, connected socket to be adopted as a TCP - // socket. This object must not be connected. This object takes ownership of - // the given socket and then acts as if Connect() had been called. This - // function is used by TCPServerSocket() to adopt accepted connections - // and for testing. - int AdoptSocket(int socket); - - // Binds the socket to a local IP address and port. - int Bind(const IPEndPoint& address); - - // StreamSocket implementation. - virtual int Connect(const CompletionCallback& callback) OVERRIDE; - virtual void Disconnect() OVERRIDE; - virtual bool IsConnected() const OVERRIDE; - virtual bool IsConnectedAndIdle() const OVERRIDE; - virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE; - virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE; - virtual const BoundNetLog& NetLog() const OVERRIDE; - virtual void SetSubresourceSpeculation() OVERRIDE; - virtual void SetOmniboxSpeculation() OVERRIDE; - virtual bool WasEverUsed() const OVERRIDE; - virtual bool UsingTCPFastOpen() const OVERRIDE; - virtual bool WasNpnNegotiated() const OVERRIDE; - virtual NextProto GetNegotiatedProtocol() const OVERRIDE; - virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE; - - // Socket implementation. - // Multiple outstanding requests are not supported. - // Full duplex mode (reading and writing at the same time) is supported - virtual int Read(IOBuffer* buf, - int buf_len, - const CompletionCallback& callback) OVERRIDE; - virtual int Write(IOBuffer* buf, - int buf_len, - const CompletionCallback& callback) OVERRIDE; - virtual bool SetReceiveBufferSize(int32 size) OVERRIDE; - virtual bool SetSendBufferSize(int32 size) OVERRIDE; - - virtual bool SetKeepAlive(bool enable, int delay); - virtual bool SetNoDelay(bool no_delay); - - private: - // State machine for connecting the socket. - enum ConnectState { - CONNECT_STATE_CONNECT, - CONNECT_STATE_CONNECT_COMPLETE, - CONNECT_STATE_NONE, - }; - - // States that a fast open socket attempt can result in. - enum FastOpenStatus { - FAST_OPEN_STATUS_UNKNOWN, - - // The initial fast open connect attempted returned synchronously, - // indicating that we had and sent a cookie along with the initial data. - FAST_OPEN_FAST_CONNECT_RETURN, - - // The initial fast open connect attempted returned asynchronously, - // indicating that we did not have a cookie for the server. - FAST_OPEN_SLOW_CONNECT_RETURN, - - // Some other error occurred on connection, so we couldn't tell if - // fast open would have worked. - FAST_OPEN_ERROR, - - // An attempt to do a fast open succeeded immediately - // (FAST_OPEN_FAST_CONNECT_RETURN) and we later confirmed that the server - // had acked the data we sent. - FAST_OPEN_SYN_DATA_ACK, - - // An attempt to do a fast open succeeded immediately - // (FAST_OPEN_FAST_CONNECT_RETURN) and we later confirmed that the server - // had nacked the data we sent. - FAST_OPEN_SYN_DATA_NACK, - - // An attempt to do a fast open succeeded immediately - // (FAST_OPEN_FAST_CONNECT_RETURN) and our probe to determine if the - // socket was using fast open failed. - FAST_OPEN_SYN_DATA_FAILED, - - // An attempt to do a fast open failed (FAST_OPEN_SLOW_CONNECT_RETURN) - // and we later confirmed that the server had acked initial data. This - // should never happen (we didn't send data, so it shouldn't have - // been acked). - FAST_OPEN_NO_SYN_DATA_ACK, - - // An attempt to do a fast open failed (FAST_OPEN_SLOW_CONNECT_RETURN) - // and we later discovered that the server had nacked initial data. This - // is the expected case results for FAST_OPEN_SLOW_CONNECT_RETURN. - FAST_OPEN_NO_SYN_DATA_NACK, - - // An attempt to do a fast open failed (FAST_OPEN_SLOW_CONNECT_RETURN) - // and our later probe for ack/nack state failed. - FAST_OPEN_NO_SYN_DATA_FAILED, - - FAST_OPEN_MAX_VALUE - }; - - class ReadWatcher : public base::MessageLoopForIO::Watcher { - public: - explicit ReadWatcher(TCPClientSocketLibevent* socket) : socket_(socket) {} - - // MessageLoopForIO::Watcher methods - - virtual void OnFileCanReadWithoutBlocking(int /* fd */) OVERRIDE; - - virtual void OnFileCanWriteWithoutBlocking(int /* fd */) OVERRIDE {} - - private: - TCPClientSocketLibevent* const socket_; - - DISALLOW_COPY_AND_ASSIGN(ReadWatcher); - }; - - class WriteWatcher : public base::MessageLoopForIO::Watcher { - public: - explicit WriteWatcher(TCPClientSocketLibevent* socket) : socket_(socket) {} - - // MessageLoopForIO::Watcher implementation. - virtual void OnFileCanReadWithoutBlocking(int /* fd */) OVERRIDE {} - virtual void OnFileCanWriteWithoutBlocking(int /* fd */) OVERRIDE; - - private: - TCPClientSocketLibevent* const socket_; - - DISALLOW_COPY_AND_ASSIGN(WriteWatcher); - }; - - // State machine used by Connect(). - int DoConnectLoop(int result); - int DoConnect(); - int DoConnectComplete(int result); - - // Helper used by Disconnect(), which disconnects minus the logging and - // resetting of current_address_index_. - void DoDisconnect(); - - void DoReadCallback(int rv); - void DoWriteCallback(int rv); - void DidCompleteRead(); - void DidCompleteWrite(); - void DidCompleteConnect(); - - // Returns true if a Connect() is in progress. - bool waiting_connect() const { - return next_connect_state_ != CONNECT_STATE_NONE; - } - - // Helper to add a TCP_CONNECT (end) event to the NetLog. - void LogConnectCompletion(int net_error); - - // Internal function to write to a socket. - int InternalWrite(IOBuffer* buf, int buf_len); - - // Called when the socket is known to be in a connected state. - void RecordFastOpenStatus(); - - int socket_; - - // Local IP address and port we are bound to. Set to NULL if Bind() - // was't called (in that cases OS chooses address/port). - scoped_ptr<IPEndPoint> bind_address_; - - // Stores bound socket between Bind() and Connect() calls. - int bound_socket_; - - // The list of addresses we should try in order to establish a connection. - AddressList addresses_; - - // Where we are in above list. Set to -1 if uninitialized. - int current_address_index_; - - // The socket's libevent wrappers - base::MessageLoopForIO::FileDescriptorWatcher read_socket_watcher_; - base::MessageLoopForIO::FileDescriptorWatcher write_socket_watcher_; - - // The corresponding watchers for reads and writes. - ReadWatcher read_watcher_; - WriteWatcher write_watcher_; - - // The buffer used by OnSocketReady to retry Read requests - scoped_refptr<IOBuffer> read_buf_; - int read_buf_len_; - - // The buffer used by OnSocketReady to retry Write requests - scoped_refptr<IOBuffer> write_buf_; - int write_buf_len_; - - // External callback; called when read is complete. - CompletionCallback read_callback_; - - // External callback; called when write is complete. - CompletionCallback write_callback_; - - // The next state for the Connect() state machine. - ConnectState next_connect_state_; - - // The OS error that CONNECT_STATE_CONNECT last completed with. - int connect_os_error_; - - BoundNetLog net_log_; - - // This socket was previously disconnected and has not been re-connected. - bool previously_disconnected_; - - // Record of connectivity and transmissions, for use in speculative connection - // histograms. - UseHistory use_history_; - - // Enables experimental TCP FastOpen option. - const bool use_tcp_fastopen_; - - // True when TCP FastOpen is in use and we have done the connect. - bool tcp_fastopen_connected_; - - enum FastOpenStatus fast_open_status_; - - DISALLOW_COPY_AND_ASSIGN(TCPClientSocketLibevent); -}; - -} // namespace net - -#endif // NET_SOCKET_TCP_CLIENT_SOCKET_LIBEVENT_H_ diff --git a/chromium/net/socket/tcp_client_socket_win.h b/chromium/net/socket/tcp_client_socket_win.h deleted file mode 100644 index 26c8b9feff2..00000000000 --- a/chromium/net/socket/tcp_client_socket_win.h +++ /dev/null @@ -1,162 +0,0 @@ -// Copyright (c) 2012 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#ifndef NET_SOCKET_TCP_CLIENT_SOCKET_WIN_H_ -#define NET_SOCKET_TCP_CLIENT_SOCKET_WIN_H_ - -#include <winsock2.h> - -#include "base/memory/scoped_ptr.h" -#include "base/threading/non_thread_safe.h" -#include "net/base/address_list.h" -#include "net/base/completion_callback.h" -#include "net/base/net_log.h" -#include "net/socket/stream_socket.h" - -namespace net { - -class BoundNetLog; - -class NET_EXPORT TCPClientSocketWin : public StreamSocket, - NON_EXPORTED_BASE(base::NonThreadSafe) { - public: - // The IP address(es) and port number to connect to. The TCP socket will try - // each IP address in the list until it succeeds in establishing a - // connection. - TCPClientSocketWin(const AddressList& addresses, - net::NetLog* net_log, - const net::NetLog::Source& source); - - virtual ~TCPClientSocketWin(); - - // AdoptSocket causes the given, connected socket to be adopted as a TCP - // socket. This object must not be connected. This object takes ownership of - // the given socket and then acts as if Connect() had been called. This - // function is used by TCPServerSocket() to adopt accepted connections - // and for testing. - int AdoptSocket(SOCKET socket); - - // Binds the socket to a local IP address and port. - int Bind(const IPEndPoint& address); - - // StreamSocket implementation. - virtual int Connect(const CompletionCallback& callback); - virtual void Disconnect(); - virtual bool IsConnected() const; - virtual bool IsConnectedAndIdle() const; - virtual int GetPeerAddress(IPEndPoint* address) const; - virtual int GetLocalAddress(IPEndPoint* address) const; - virtual const BoundNetLog& NetLog() const { return net_log_; } - virtual void SetSubresourceSpeculation(); - virtual void SetOmniboxSpeculation(); - virtual bool WasEverUsed() const; - virtual bool UsingTCPFastOpen() const; - virtual bool WasNpnNegotiated() const OVERRIDE; - virtual NextProto GetNegotiatedProtocol() const OVERRIDE; - virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE; - - // Socket implementation. - // Multiple outstanding requests are not supported. - // Full duplex mode (reading and writing at the same time) is supported - virtual int Read(IOBuffer* buf, int buf_len, - const CompletionCallback& callback); - virtual int Write(IOBuffer* buf, int buf_len, - const CompletionCallback& callback); - - virtual bool SetReceiveBufferSize(int32 size); - virtual bool SetSendBufferSize(int32 size); - - virtual bool SetKeepAlive(bool enable, int delay); - virtual bool SetNoDelay(bool no_delay); - - // Perform reads in non-blocking mode instead of overlapped mode. - // Used for experiments. - static void DisableOverlappedReads(); - - private: - // State machine for connecting the socket. - enum ConnectState { - CONNECT_STATE_CONNECT, - CONNECT_STATE_CONNECT_COMPLETE, - CONNECT_STATE_NONE, - }; - - class Core; - - // State machine used by Connect(). - int DoConnectLoop(int result); - int DoConnect(); - int DoConnectComplete(int result); - - // Helper used by Disconnect(), which disconnects minus the logging and - // resetting of current_address_index_. - void DoDisconnect(); - - // Returns true if a Connect() is in progress. - bool waiting_connect() const { - return next_connect_state_ != CONNECT_STATE_NONE; - } - - // Called after Connect() has completed with |net_error|. - void LogConnectCompletion(int net_error); - - int DoRead(IOBuffer* buf, int buf_len, const CompletionCallback& callback); - void DoReadCallback(int rv); - void DoWriteCallback(int rv); - void DidCompleteConnect(); - void DidCompleteRead(); - void DidCompleteWrite(); - void DidSignalRead(); - - SOCKET socket_; - - // Local IP address and port we are bound to. Set to NULL if Bind() - // was't called (in that cases OS chooses address/port). - scoped_ptr<IPEndPoint> bind_address_; - - // Stores bound socket between Bind() and Connect() calls. - SOCKET bound_socket_; - - // The list of addresses we should try in order to establish a connection. - AddressList addresses_; - - // Where we are in above list. Set to -1 if uninitialized. - int current_address_index_; - - // The various states that the socket could be in. - bool waiting_read_; - bool waiting_write_; - - // The core of the socket that can live longer than the socket itself. We pass - // resources to the Windows async IO functions and we have to make sure that - // they are not destroyed while the OS still references them. - scoped_refptr<Core> core_; - - // External callback; called when connect or read is complete. - CompletionCallback read_callback_; - - // External callback; called when write is complete. - CompletionCallback write_callback_; - - // The next state for the Connect() state machine. - ConnectState next_connect_state_; - - // The OS error that CONNECT_STATE_CONNECT last completed with. - int connect_os_error_; - - BoundNetLog net_log_; - - // This socket was previously disconnected and has not been re-connected. - bool previously_disconnected_; - - // Record of connectivity and transmissions, for use in speculative connection - // histograms. - UseHistory use_history_; - - DISALLOW_COPY_AND_ASSIGN(TCPClientSocketWin); -}; - -} // namespace net - -#endif // NET_SOCKET_TCP_CLIENT_SOCKET_WIN_H_ diff --git a/chromium/net/socket/tcp_listen_socket.cc b/chromium/net/socket/tcp_listen_socket.cc index aab2e45d0e9..223abee2cba 100644 --- a/chromium/net/socket/tcp_listen_socket.cc +++ b/chromium/net/socket/tcp_listen_socket.cc @@ -23,20 +23,21 @@ #include "build/build_config.h" #include "net/base/net_util.h" #include "net/base/winsock_init.h" +#include "net/socket/socket_descriptor.h" using std::string; namespace net { // static -scoped_refptr<TCPListenSocket> TCPListenSocket::CreateAndListen( +scoped_ptr<TCPListenSocket> TCPListenSocket::CreateAndListen( const string& ip, int port, StreamListenSocket::Delegate* del) { SocketDescriptor s = CreateAndBind(ip, port); if (s == kInvalidSocket) - return NULL; - scoped_refptr<TCPListenSocket> sock(new TCPListenSocket(s, del)); + return scoped_ptr<TCPListenSocket>(); + scoped_ptr<TCPListenSocket> sock(new TCPListenSocket(s, del)); sock->Listen(); - return sock; + return sock.Pass(); } TCPListenSocket::TCPListenSocket(SocketDescriptor s, @@ -47,11 +48,7 @@ TCPListenSocket::TCPListenSocket(SocketDescriptor s, TCPListenSocket::~TCPListenSocket() {} SocketDescriptor TCPListenSocket::CreateAndBind(const string& ip, int port) { -#if defined(OS_WIN) - EnsureWinsockInit(); -#endif - - SocketDescriptor s = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); + SocketDescriptor s = CreatePlatformSocket(AF_INET, SOCK_STREAM, IPPROTO_TCP); if (s != kInvalidSocket) { #if defined(OS_POSIX) // Allow rapid reuse. @@ -104,13 +101,13 @@ void TCPListenSocket::Accept() { SocketDescriptor conn = AcceptSocket(); if (conn == kInvalidSocket) return; - scoped_refptr<TCPListenSocket> sock( + scoped_ptr<TCPListenSocket> sock( new TCPListenSocket(conn, socket_delegate_)); // It's up to the delegate to AddRef if it wants to keep it around. #if defined(OS_POSIX) sock->WatchSocket(WAITING_READ); #endif - socket_delegate_->DidAccept(this, sock.get()); + socket_delegate_->DidAccept(this, sock.PassAs<StreamListenSocket>()); } TCPListenSocketFactory::TCPListenSocketFactory(const string& ip, int port) @@ -120,9 +117,10 @@ TCPListenSocketFactory::TCPListenSocketFactory(const string& ip, int port) TCPListenSocketFactory::~TCPListenSocketFactory() {} -scoped_refptr<StreamListenSocket> TCPListenSocketFactory::CreateAndListen( +scoped_ptr<StreamListenSocket> TCPListenSocketFactory::CreateAndListen( StreamListenSocket::Delegate* delegate) const { - return TCPListenSocket::CreateAndListen(ip_, port_, delegate); + return TCPListenSocket::CreateAndListen(ip_, port_, delegate) + .PassAs<StreamListenSocket>(); } } // namespace net diff --git a/chromium/net/socket/tcp_listen_socket.h b/chromium/net/socket/tcp_listen_socket.h index dbc5347e945..54a91de59bb 100644 --- a/chromium/net/socket/tcp_listen_socket.h +++ b/chromium/net/socket/tcp_listen_socket.h @@ -8,18 +8,19 @@ #include <string> #include "base/basictypes.h" -#include "base/memory/ref_counted.h" #include "net/base/net_export.h" +#include "net/socket/socket_descriptor.h" #include "net/socket/stream_listen_socket.h" namespace net { -// Implements a TCP socket. Note that this is ref counted. +// Implements a TCP socket. class NET_EXPORT TCPListenSocket : public StreamListenSocket { public: + virtual ~TCPListenSocket(); // Listen on port for the specified IP address. Use 127.0.0.1 to only // accept local connections. - static scoped_refptr<TCPListenSocket> CreateAndListen( + static scoped_ptr<TCPListenSocket> CreateAndListen( const std::string& ip, int port, StreamListenSocket::Delegate* del); // Get raw TCP socket descriptor bound to ip:port. @@ -30,10 +31,7 @@ class NET_EXPORT TCPListenSocket : public StreamListenSocket { int* port); protected: - friend class scoped_refptr<TCPListenSocket>; - TCPListenSocket(SocketDescriptor s, StreamListenSocket::Delegate* del); - virtual ~TCPListenSocket(); // Implements StreamListenSocket::Accept. virtual void Accept() OVERRIDE; @@ -49,7 +47,7 @@ class NET_EXPORT TCPListenSocketFactory : public StreamListenSocketFactory { virtual ~TCPListenSocketFactory(); // StreamListenSocketFactory overrides. - virtual scoped_refptr<StreamListenSocket> CreateAndListen( + virtual scoped_ptr<StreamListenSocket> CreateAndListen( StreamListenSocket::Delegate* delegate) const OVERRIDE; private: diff --git a/chromium/net/socket/tcp_listen_socket_unittest.cc b/chromium/net/socket/tcp_listen_socket_unittest.cc index d13b784cbdc..b122c6143d8 100644 --- a/chromium/net/socket/tcp_listen_socket_unittest.cc +++ b/chromium/net/socket/tcp_listen_socket_unittest.cc @@ -10,13 +10,14 @@ #include "base/bind.h" #include "base/posix/eintr_wrapper.h" #include "base/sys_byteorder.h" +#include "net/base/ip_endpoint.h" +#include "net/base/net_errors.h" #include "net/base/net_util.h" +#include "net/socket/socket_descriptor.h" #include "testing/platform_test.h" namespace net { -const int TCPListenSocketTester::kTestPort = 9999; - static const int kReadBufSize = 1024; static const char kHelloWorld[] = "HELLO, WORLD"; static const int kMaxQueueSize = 20; @@ -24,7 +25,9 @@ static const char kLoopback[] = "127.0.0.1"; static const int kDefaultTimeoutMs = 5000; TCPListenSocketTester::TCPListenSocketTester() - : loop_(NULL), server_(NULL), connection_(NULL), cv_(&lock_) {} + : loop_(NULL), + cv_(&lock_), + server_port_(0) {} void TCPListenSocketTester::SetUp() { base::Thread::Options options; @@ -41,13 +44,16 @@ void TCPListenSocketTester::SetUp() { ASSERT_FALSE(server_.get() == NULL); ASSERT_EQ(ACTION_LISTEN, last_action_.type()); + int server_port = GetServerPort(); + ASSERT_GT(server_port, 0); + // verify the connect/accept and setup test_socket_ - test_socket_ = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); - ASSERT_NE(StreamListenSocket::kInvalidSocket, test_socket_); + test_socket_ = CreatePlatformSocket(AF_INET, SOCK_STREAM, IPPROTO_TCP); + ASSERT_NE(kInvalidSocket, test_socket_); struct sockaddr_in client; client.sin_family = AF_INET; client.sin_addr.s_addr = inet_addr(kLoopback); - client.sin_port = base::HostToNet16(kTestPort); + client.sin_port = base::HostToNet16(server_port); int ret = HANDLE_EINTR( connect(test_socket_, reinterpret_cast<sockaddr*>(&client), sizeof(client))); @@ -113,17 +119,20 @@ int TCPListenSocketTester::ClearTestSocket() { } void TCPListenSocketTester::Shutdown() { - connection_->Release(); - connection_ = NULL; - server_->Release(); - server_ = NULL; + connection_.reset(); + server_.reset(); ReportAction(TCPListenSocketTestAction(ACTION_SHUTDOWN)); } void TCPListenSocketTester::Listen() { server_ = DoListen(); ASSERT_TRUE(server_.get()); - server_->AddRef(); + + // The server's port will be needed to open the client socket. + IPEndPoint local_address; + ASSERT_EQ(OK, server_->GetLocalAddress(&local_address)); + SetServerPort(local_address.port()); + ReportAction(TCPListenSocketTestAction(ACTION_LISTEN)); } @@ -227,10 +236,10 @@ bool TCPListenSocketTester::Send(SocketDescriptor sock, return true; } -void TCPListenSocketTester::DidAccept(StreamListenSocket* server, - StreamListenSocket* connection) { - connection_ = connection; - connection_->AddRef(); +void TCPListenSocketTester::DidAccept( + StreamListenSocket* server, + scoped_ptr<StreamListenSocket> connection) { + connection_ = connection.Pass(); ReportAction(TCPListenSocketTestAction(ACTION_ACCEPT)); } @@ -247,11 +256,22 @@ void TCPListenSocketTester::DidClose(StreamListenSocket* sock) { TCPListenSocketTester::~TCPListenSocketTester() {} -scoped_refptr<TCPListenSocket> TCPListenSocketTester::DoListen() { - return TCPListenSocket::CreateAndListen(kLoopback, kTestPort, this); +scoped_ptr<TCPListenSocket> TCPListenSocketTester::DoListen() { + // Let the OS pick a free port. + return TCPListenSocket::CreateAndListen(kLoopback, 0, this); +} + +int TCPListenSocketTester::GetServerPort() { + base::AutoLock locked(lock_); + return server_port_; +} + +void TCPListenSocketTester::SetServerPort(int server_port) { + base::AutoLock locked(lock_); + server_port_ = server_port; } -class TCPListenSocketTest: public PlatformTest { +class TCPListenSocketTest : public PlatformTest { public: TCPListenSocketTest() { tester_ = NULL; diff --git a/chromium/net/socket/tcp_listen_socket_unittest.h b/chromium/net/socket/tcp_listen_socket_unittest.h index 048a0186705..1bc31a8d1ce 100644 --- a/chromium/net/socket/tcp_listen_socket_unittest.h +++ b/chromium/net/socket/tcp_listen_socket_unittest.h @@ -91,30 +91,37 @@ class TCPListenSocketTester : // StreamListenSocket::Delegate: virtual void DidAccept(StreamListenSocket* server, - StreamListenSocket* connection) OVERRIDE; + scoped_ptr<StreamListenSocket> connection) OVERRIDE; virtual void DidRead(StreamListenSocket* connection, const char* data, int len) OVERRIDE; virtual void DidClose(StreamListenSocket* sock) OVERRIDE; scoped_ptr<base::Thread> thread_; base::MessageLoopForIO* loop_; - scoped_refptr<TCPListenSocket> server_; - StreamListenSocket* connection_; + scoped_ptr<TCPListenSocket> server_; + scoped_ptr<StreamListenSocket> connection_; TCPListenSocketTestAction last_action_; SocketDescriptor test_socket_; - static const int kTestPort; - base::Lock lock_; // protects |queue_| and wraps |cv_| + base::Lock lock_; // Protects |queue_| and |server_port_|. Wraps |cv_|. base::ConditionVariable cv_; std::deque<TCPListenSocketTestAction> queue_; - protected: + private: friend class base::RefCountedThreadSafe<TCPListenSocketTester>; virtual ~TCPListenSocketTester(); - virtual scoped_refptr<TCPListenSocket> DoListen(); + virtual scoped_ptr<TCPListenSocket> DoListen(); + + // Getters/setters for |server_port_|. They use |lock_| for thread safety. + int GetServerPort(); + void SetServerPort(int server_port); + + // Port the server is using. Must have |lock_| to access. Set by Listen() on + // the server's thread. + int server_port_; }; } // namespace net diff --git a/chromium/net/socket/tcp_server_socket.cc b/chromium/net/socket/tcp_server_socket.cc new file mode 100644 index 00000000000..a25f73f6c6f --- /dev/null +++ b/chromium/net/socket/tcp_server_socket.cc @@ -0,0 +1,105 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/tcp_server_socket.h" + +#include "base/bind.h" +#include "base/bind_helpers.h" +#include "base/logging.h" +#include "net/base/net_errors.h" +#include "net/socket/tcp_client_socket.h" + +namespace net { + +TCPServerSocket::TCPServerSocket(NetLog* net_log, const NetLog::Source& source) + : socket_(net_log, source), + pending_accept_(false) { +} + +TCPServerSocket::~TCPServerSocket() { +} + +int TCPServerSocket::Listen(const IPEndPoint& address, int backlog) { + int result = socket_.Open(address.GetFamily()); + if (result != OK) + return result; + + result = socket_.SetDefaultOptionsForServer(); + if (result != OK) { + socket_.Close(); + return result; + } + + result = socket_.Bind(address); + if (result != OK) { + socket_.Close(); + return result; + } + + result = socket_.Listen(backlog); + if (result != OK) { + socket_.Close(); + return result; + } + + return OK; +} + +int TCPServerSocket::GetLocalAddress(IPEndPoint* address) const { + return socket_.GetLocalAddress(address); +} + +int TCPServerSocket::Accept(scoped_ptr<StreamSocket>* socket, + const CompletionCallback& callback) { + DCHECK(socket); + DCHECK(!callback.is_null()); + + if (pending_accept_) { + NOTREACHED(); + return ERR_UNEXPECTED; + } + + // It is safe to use base::Unretained(this). |socket_| is owned by this class, + // and the callback won't be run after |socket_| is destroyed. + CompletionCallback accept_callback = + base::Bind(&TCPServerSocket::OnAcceptCompleted, base::Unretained(this), + socket, callback); + int result = socket_.Accept(&accepted_socket_, &accepted_address_, + accept_callback); + if (result != ERR_IO_PENDING) { + // |accept_callback| won't be called so we need to run + // ConvertAcceptedSocket() ourselves in order to do the conversion from + // |accepted_socket_| to |socket|. + result = ConvertAcceptedSocket(result, socket); + } else { + pending_accept_ = true; + } + + return result; +} + +int TCPServerSocket::ConvertAcceptedSocket( + int result, + scoped_ptr<StreamSocket>* output_accepted_socket) { + // Make sure the TCPSocket object is destroyed in any case. + scoped_ptr<TCPSocket> temp_accepted_socket(accepted_socket_.Pass()); + if (result != OK) + return result; + + output_accepted_socket->reset(new TCPClientSocket( + temp_accepted_socket.Pass(), accepted_address_)); + + return OK; +} + +void TCPServerSocket::OnAcceptCompleted( + scoped_ptr<StreamSocket>* output_accepted_socket, + const CompletionCallback& forward_callback, + int result) { + result = ConvertAcceptedSocket(result, output_accepted_socket); + pending_accept_ = false; + forward_callback.Run(result); +} + +} // namespace net diff --git a/chromium/net/socket/tcp_server_socket.h b/chromium/net/socket/tcp_server_socket.h index 4970a150e8d..faff9ad826a 100644 --- a/chromium/net/socket/tcp_server_socket.h +++ b/chromium/net/socket/tcp_server_socket.h @@ -5,21 +5,48 @@ #ifndef NET_SOCKET_TCP_SERVER_SOCKET_H_ #define NET_SOCKET_TCP_SERVER_SOCKET_H_ -#include "build/build_config.h" - -#if defined(OS_WIN) -#include "net/socket/tcp_server_socket_win.h" -#elif defined(OS_POSIX) -#include "net/socket/tcp_server_socket_libevent.h" -#endif +#include "base/basictypes.h" +#include "base/compiler_specific.h" +#include "base/memory/scoped_ptr.h" +#include "net/base/ip_endpoint.h" +#include "net/base/net_export.h" +#include "net/base/net_log.h" +#include "net/socket/server_socket.h" +#include "net/socket/tcp_socket.h" namespace net { -#if defined(OS_WIN) -typedef TCPServerSocketWin TCPServerSocket; -#elif defined(OS_POSIX) -typedef TCPServerSocketLibevent TCPServerSocket; -#endif +class NET_EXPORT_PRIVATE TCPServerSocket : public ServerSocket { + public: + TCPServerSocket(NetLog* net_log, const NetLog::Source& source); + virtual ~TCPServerSocket(); + + // net::ServerSocket implementation. + virtual int Listen(const IPEndPoint& address, int backlog) OVERRIDE; + virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE; + virtual int Accept(scoped_ptr<StreamSocket>* socket, + const CompletionCallback& callback) OVERRIDE; + + private: + // Converts |accepted_socket_| and stores the result in + // |output_accepted_socket|. + // |output_accepted_socket| is untouched on failure. But |accepted_socket_| is + // set to NULL in any case. + int ConvertAcceptedSocket(int result, + scoped_ptr<StreamSocket>* output_accepted_socket); + // Completion callback for calling TCPSocket::Accept(). + void OnAcceptCompleted(scoped_ptr<StreamSocket>* output_accepted_socket, + const CompletionCallback& forward_callback, + int result); + + TCPSocket socket_; + + scoped_ptr<TCPSocket> accepted_socket_; + IPEndPoint accepted_address_; + bool pending_accept_; + + DISALLOW_COPY_AND_ASSIGN(TCPServerSocket); +}; } // namespace net diff --git a/chromium/net/socket/tcp_server_socket_libevent.cc b/chromium/net/socket/tcp_server_socket_libevent.cc deleted file mode 100644 index 38dda962f46..00000000000 --- a/chromium/net/socket/tcp_server_socket_libevent.cc +++ /dev/null @@ -1,223 +0,0 @@ -// Copyright (c) 2012 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "net/socket/tcp_server_socket_libevent.h" - -#include <errno.h> -#include <fcntl.h> -#include <netdb.h> -#include <sys/socket.h> - -#include "build/build_config.h" - -#if defined(OS_POSIX) -#include <netinet/in.h> -#endif - -#include "base/posix/eintr_wrapper.h" -#include "net/base/ip_endpoint.h" -#include "net/base/net_errors.h" -#include "net/base/net_util.h" -#include "net/socket/socket_net_log_params.h" -#include "net/socket/tcp_client_socket.h" - -namespace net { - -namespace { - -const int kInvalidSocket = -1; - -} // namespace - -TCPServerSocketLibevent::TCPServerSocketLibevent( - net::NetLog* net_log, - const net::NetLog::Source& source) - : socket_(kInvalidSocket), - accept_socket_(NULL), - net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)) { - net_log_.BeginEvent(NetLog::TYPE_SOCKET_ALIVE, - source.ToEventParametersCallback()); -} - -TCPServerSocketLibevent::~TCPServerSocketLibevent() { - if (socket_ != kInvalidSocket) - Close(); - net_log_.EndEvent(NetLog::TYPE_SOCKET_ALIVE); -} - -int TCPServerSocketLibevent::Listen(const IPEndPoint& address, int backlog) { - DCHECK(CalledOnValidThread()); - DCHECK_GT(backlog, 0); - DCHECK_EQ(socket_, kInvalidSocket); - - socket_ = socket(address.GetSockAddrFamily(), SOCK_STREAM, IPPROTO_TCP); - if (socket_ < 0) { - PLOG(ERROR) << "socket() returned an error"; - return MapSystemError(errno); - } - - if (SetNonBlocking(socket_)) { - int result = MapSystemError(errno); - Close(); - return result; - } - - int result = SetSocketOptions(); - if (result != OK) { - Close(); - return result; - } - - SockaddrStorage storage; - if (!address.ToSockAddr(storage.addr, &storage.addr_len)) { - Close(); - return ERR_ADDRESS_INVALID; - } - - result = bind(socket_, storage.addr, storage.addr_len); - if (result < 0) { - PLOG(ERROR) << "bind() returned an error"; - result = MapSystemError(errno); - Close(); - return result; - } - - result = listen(socket_, backlog); - if (result < 0) { - PLOG(ERROR) << "listen() returned an error"; - result = MapSystemError(errno); - Close(); - return result; - } - - return OK; -} - -int TCPServerSocketLibevent::GetLocalAddress(IPEndPoint* address) const { - DCHECK(CalledOnValidThread()); - DCHECK(address); - - SockaddrStorage storage; - if (getsockname(socket_, storage.addr, &storage.addr_len) < 0) - return MapSystemError(errno); - if (!address->FromSockAddr(storage.addr, storage.addr_len)) - return ERR_FAILED; - - return OK; -} - -int TCPServerSocketLibevent::Accept( - scoped_ptr<StreamSocket>* socket, const CompletionCallback& callback) { - DCHECK(CalledOnValidThread()); - DCHECK(socket); - DCHECK(!callback.is_null()); - DCHECK(accept_callback_.is_null()); - - net_log_.BeginEvent(NetLog::TYPE_TCP_ACCEPT); - - int result = AcceptInternal(socket); - - if (result == ERR_IO_PENDING) { - if (!base::MessageLoopForIO::current()->WatchFileDescriptor( - socket_, true, base::MessageLoopForIO::WATCH_READ, - &accept_socket_watcher_, this)) { - PLOG(ERROR) << "WatchFileDescriptor failed on read"; - return MapSystemError(errno); - } - - accept_socket_ = socket; - accept_callback_ = callback; - } - - return result; -} - -int TCPServerSocketLibevent::SetSocketOptions() { - // SO_REUSEADDR is useful for server sockets to bind to a recently unbound - // port. When a socket is closed, the end point changes its state to TIME_WAIT - // and wait for 2 MSL (maximum segment lifetime) to ensure the remote peer - // acknowledges its closure. For server sockets, it is usually safe to - // bind to a TIME_WAIT end point immediately, which is a widely adopted - // behavior. - // - // Note that on *nix, SO_REUSEADDR does not enable the TCP socket to bind to - // an end point that is already bound by another socket. To do that one must - // set SO_REUSEPORT instead. This option is not provided on Linux prior - // to 3.9. - // - // SO_REUSEPORT is provided in MacOS X and iOS. - int true_value = 1; - int rv = setsockopt(socket_, SOL_SOCKET, SO_REUSEADDR, &true_value, - sizeof(true_value)); - if (rv < 0) - return MapSystemError(errno); - return OK; -} - -int TCPServerSocketLibevent::AcceptInternal( - scoped_ptr<StreamSocket>* socket) { - SockaddrStorage storage; - int new_socket = HANDLE_EINTR(accept(socket_, - storage.addr, - &storage.addr_len)); - if (new_socket < 0) { - int net_error = MapSystemError(errno); - if (net_error != ERR_IO_PENDING) - net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_ACCEPT, net_error); - return net_error; - } - - IPEndPoint address; - if (!address.FromSockAddr(storage.addr, storage.addr_len)) { - NOTREACHED(); - if (HANDLE_EINTR(close(new_socket)) < 0) - PLOG(ERROR) << "close"; - net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_ACCEPT, ERR_FAILED); - return ERR_FAILED; - } - scoped_ptr<TCPClientSocket> tcp_socket(new TCPClientSocket( - AddressList(address), - net_log_.net_log(), net_log_.source())); - int adopt_result = tcp_socket->AdoptSocket(new_socket); - if (adopt_result != OK) { - if (HANDLE_EINTR(close(new_socket)) < 0) - PLOG(ERROR) << "close"; - net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_ACCEPT, adopt_result); - return adopt_result; - } - socket->reset(tcp_socket.release()); - net_log_.EndEvent(NetLog::TYPE_TCP_ACCEPT, - CreateNetLogIPEndPointCallback(&address)); - return OK; -} - -void TCPServerSocketLibevent::Close() { - if (socket_ != kInvalidSocket) { - bool ok = accept_socket_watcher_.StopWatchingFileDescriptor(); - DCHECK(ok); - if (HANDLE_EINTR(close(socket_)) < 0) - PLOG(ERROR) << "close"; - socket_ = kInvalidSocket; - } -} - -void TCPServerSocketLibevent::OnFileCanReadWithoutBlocking(int fd) { - DCHECK(CalledOnValidThread()); - - int result = AcceptInternal(accept_socket_); - if (result != ERR_IO_PENDING) { - accept_socket_ = NULL; - bool ok = accept_socket_watcher_.StopWatchingFileDescriptor(); - DCHECK(ok); - CompletionCallback callback = accept_callback_; - accept_callback_.Reset(); - callback.Run(result); - } -} - -void TCPServerSocketLibevent::OnFileCanWriteWithoutBlocking(int fd) { - NOTREACHED(); -} - -} // namespace net diff --git a/chromium/net/socket/tcp_server_socket_libevent.h b/chromium/net/socket/tcp_server_socket_libevent.h deleted file mode 100644 index fe69472a653..00000000000 --- a/chromium/net/socket/tcp_server_socket_libevent.h +++ /dev/null @@ -1,55 +0,0 @@ -// Copyright (c) 2011 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#ifndef NET_SOCKET_TCP_SERVER_SOCKET_LIBEVENT_H_ -#define NET_SOCKET_TCP_SERVER_SOCKET_LIBEVENT_H_ - -#include "base/memory/scoped_ptr.h" -#include "base/message_loop/message_loop.h" -#include "base/threading/non_thread_safe.h" -#include "net/base/completion_callback.h" -#include "net/base/net_log.h" -#include "net/socket/server_socket.h" - -namespace net { - -class IPEndPoint; - -class NET_EXPORT_PRIVATE TCPServerSocketLibevent : - public ServerSocket, - public base::NonThreadSafe, - public base::MessageLoopForIO::Watcher { - public: - TCPServerSocketLibevent(net::NetLog* net_log, - const net::NetLog::Source& source); - virtual ~TCPServerSocketLibevent(); - - // net::ServerSocket implementation. - virtual int Listen(const net::IPEndPoint& address, int backlog) OVERRIDE; - virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE; - virtual int Accept(scoped_ptr<StreamSocket>* socket, - const CompletionCallback& callback) OVERRIDE; - - // MessageLoopForIO::Watcher implementation. - virtual void OnFileCanReadWithoutBlocking(int fd) OVERRIDE; - virtual void OnFileCanWriteWithoutBlocking(int fd) OVERRIDE; - - private: - int SetSocketOptions(); - int AcceptInternal(scoped_ptr<StreamSocket>* socket); - void Close(); - - int socket_; - - base::MessageLoopForIO::FileDescriptorWatcher accept_socket_watcher_; - - scoped_ptr<StreamSocket>* accept_socket_; - CompletionCallback accept_callback_; - - BoundNetLog net_log_; -}; - -} // namespace net - -#endif // NET_SOCKET_TCP_SERVER_SOCKET_LIBEVENT_H_ diff --git a/chromium/net/socket/tcp_server_socket_win.cc b/chromium/net/socket/tcp_server_socket_win.cc deleted file mode 100644 index 0ac77be5e81..00000000000 --- a/chromium/net/socket/tcp_server_socket_win.cc +++ /dev/null @@ -1,217 +0,0 @@ -// Copyright (c) 2012 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "net/socket/tcp_server_socket_win.h" - -#include <mstcpip.h> - -#include "net/base/ip_endpoint.h" -#include "net/base/net_errors.h" -#include "net/base/net_util.h" -#include "net/base/winsock_init.h" -#include "net/base/winsock_util.h" -#include "net/socket/socket_net_log_params.h" -#include "net/socket/tcp_client_socket.h" - -namespace net { - -TCPServerSocketWin::TCPServerSocketWin(net::NetLog* net_log, - const net::NetLog::Source& source) - : socket_(INVALID_SOCKET), - socket_event_(WSA_INVALID_EVENT), - accept_socket_(NULL), - net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)) { - net_log_.BeginEvent(NetLog::TYPE_SOCKET_ALIVE, - source.ToEventParametersCallback()); - EnsureWinsockInit(); -} - -TCPServerSocketWin::~TCPServerSocketWin() { - Close(); - net_log_.EndEvent(NetLog::TYPE_SOCKET_ALIVE); -} - -int TCPServerSocketWin::Listen(const IPEndPoint& address, int backlog) { - DCHECK(CalledOnValidThread()); - DCHECK_GT(backlog, 0); - DCHECK_EQ(socket_, INVALID_SOCKET); - DCHECK_EQ(socket_event_, WSA_INVALID_EVENT); - - socket_event_ = WSACreateEvent(); - if (socket_event_ == WSA_INVALID_EVENT) { - PLOG(ERROR) << "WSACreateEvent()"; - return ERR_FAILED; - } - - socket_ = socket(address.GetSockAddrFamily(), SOCK_STREAM, IPPROTO_TCP); - if (socket_ == INVALID_SOCKET) { - PLOG(ERROR) << "socket() returned an error"; - return MapSystemError(WSAGetLastError()); - } - - if (SetNonBlocking(socket_)) { - int result = MapSystemError(WSAGetLastError()); - Close(); - return result; - } - - int result = SetSocketOptions(); - if (result != OK) { - Close(); - return result; - } - - SockaddrStorage storage; - if (!address.ToSockAddr(storage.addr, &storage.addr_len)) { - Close(); - return ERR_ADDRESS_INVALID; - } - - result = bind(socket_, storage.addr, storage.addr_len); - if (result < 0) { - PLOG(ERROR) << "bind() returned an error"; - result = MapSystemError(WSAGetLastError()); - Close(); - return result; - } - - result = listen(socket_, backlog); - if (result < 0) { - PLOG(ERROR) << "listen() returned an error"; - result = MapSystemError(WSAGetLastError()); - Close(); - return result; - } - - return OK; -} - -int TCPServerSocketWin::GetLocalAddress(IPEndPoint* address) const { - DCHECK(CalledOnValidThread()); - DCHECK(address); - - SockaddrStorage storage; - if (getsockname(socket_, storage.addr, &storage.addr_len)) - return MapSystemError(WSAGetLastError()); - if (!address->FromSockAddr(storage.addr, storage.addr_len)) - return ERR_FAILED; - - return OK; -} - -int TCPServerSocketWin::Accept( - scoped_ptr<StreamSocket>* socket, const CompletionCallback& callback) { - DCHECK(CalledOnValidThread()); - DCHECK(socket); - DCHECK(!callback.is_null()); - DCHECK(accept_callback_.is_null()); - - net_log_.BeginEvent(NetLog::TYPE_TCP_ACCEPT); - - int result = AcceptInternal(socket); - - if (result == ERR_IO_PENDING) { - // Start watching - WSAEventSelect(socket_, socket_event_, FD_ACCEPT); - accept_watcher_.StartWatching(socket_event_, this); - - accept_socket_ = socket; - accept_callback_ = callback; - } - - return result; -} - -int TCPServerSocketWin::SetSocketOptions() { - // On Windows, a bound end point can be hijacked by another process by - // setting SO_REUSEADDR. Therefore a Windows-only option SO_EXCLUSIVEADDRUSE - // was introduced in Windows NT 4.0 SP4. If the socket that is bound to the - // end point has SO_EXCLUSIVEADDRUSE enabled, it is not possible for another - // socket to forcibly bind to the end point until the end point is unbound. - // It is recommend that all server applications must use SO_EXCLUSIVEADDRUSE. - // MSDN: http://goo.gl/M6fjQ. - // - // Unlike on *nix, on Windows a TCP server socket can always bind to an end - // point in TIME_WAIT state without setting SO_REUSEADDR, therefore it is not - // needed here. - // - // SO_EXCLUSIVEADDRUSE will prevent a TCP client socket from binding to an end - // point in TIME_WAIT status. It does not have this effect for a TCP server - // socket. - - BOOL true_value = 1; - int rv = setsockopt(socket_, SOL_SOCKET, SO_EXCLUSIVEADDRUSE, - reinterpret_cast<const char*>(&true_value), - sizeof(true_value)); - if (rv < 0) - return MapSystemError(errno); - return OK; -} - -int TCPServerSocketWin::AcceptInternal(scoped_ptr<StreamSocket>* socket) { - SockaddrStorage storage; - int new_socket = accept(socket_, storage.addr, &storage.addr_len); - if (new_socket < 0) { - int net_error = MapSystemError(WSAGetLastError()); - if (net_error != ERR_IO_PENDING) - net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_ACCEPT, net_error); - return net_error; - } - - IPEndPoint address; - if (!address.FromSockAddr(storage.addr, storage.addr_len)) { - NOTREACHED(); - if (closesocket(new_socket) < 0) - PLOG(ERROR) << "closesocket"; - net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_ACCEPT, ERR_FAILED); - return ERR_FAILED; - } - scoped_ptr<TCPClientSocket> tcp_socket(new TCPClientSocket( - AddressList(address), - net_log_.net_log(), net_log_.source())); - int adopt_result = tcp_socket->AdoptSocket(new_socket); - if (adopt_result != OK) { - if (closesocket(new_socket) < 0) - PLOG(ERROR) << "closesocket"; - net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_ACCEPT, adopt_result); - return adopt_result; - } - socket->reset(tcp_socket.release()); - net_log_.EndEvent(NetLog::TYPE_TCP_ACCEPT, - CreateNetLogIPEndPointCallback(&address)); - return OK; -} - -void TCPServerSocketWin::Close() { - if (socket_ != INVALID_SOCKET) { - if (closesocket(socket_) < 0) - PLOG(ERROR) << "closesocket"; - socket_ = INVALID_SOCKET; - } - - if (socket_event_) { - WSACloseEvent(socket_event_); - socket_event_ = WSA_INVALID_EVENT; - } -} - -void TCPServerSocketWin::OnObjectSignaled(HANDLE object) { - WSANETWORKEVENTS ev; - if (WSAEnumNetworkEvents(socket_, socket_event_, &ev) == SOCKET_ERROR) { - PLOG(ERROR) << "WSAEnumNetworkEvents()"; - return; - } - - if (ev.lNetworkEvents & FD_ACCEPT) { - int result = AcceptInternal(accept_socket_); - if (result != ERR_IO_PENDING) { - accept_socket_ = NULL; - CompletionCallback callback = accept_callback_; - accept_callback_.Reset(); - callback.Run(result); - } - } -} - -} // namespace net diff --git a/chromium/net/socket/tcp_server_socket_win.h b/chromium/net/socket/tcp_server_socket_win.h deleted file mode 100644 index 5a1d378ad9b..00000000000 --- a/chromium/net/socket/tcp_server_socket_win.h +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright (c) 2011 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#ifndef NET_SOCKET_TCP_SERVER_SOCKET_WIN_H_ -#define NET_SOCKET_TCP_SERVER_SOCKET_WIN_H_ - -#include <winsock2.h> - -#include "base/memory/scoped_ptr.h" -#include "base/message_loop/message_loop.h" -#include "base/threading/non_thread_safe.h" -#include "base/win/object_watcher.h" -#include "net/base/completion_callback.h" -#include "net/base/net_log.h" -#include "net/socket/server_socket.h" - -namespace net { - -class IPEndPoint; - -class NET_EXPORT_PRIVATE TCPServerSocketWin - : public ServerSocket, - NON_EXPORTED_BASE(public base::NonThreadSafe), - public base::win::ObjectWatcher::Delegate { - public: - TCPServerSocketWin(net::NetLog* net_log, - const net::NetLog::Source& source); - ~TCPServerSocketWin(); - - // net::ServerSocket implementation. - virtual int Listen(const net::IPEndPoint& address, int backlog) OVERRIDE; - virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE; - virtual int Accept(scoped_ptr<StreamSocket>* socket, - const CompletionCallback& callback) OVERRIDE; - - // base::ObjectWatcher::Delegate implementation. - virtual void OnObjectSignaled(HANDLE object); - - private: - int SetSocketOptions(); - int AcceptInternal(scoped_ptr<StreamSocket>* socket); - void Close(); - - SOCKET socket_; - HANDLE socket_event_; - - base::win::ObjectWatcher accept_watcher_; - - scoped_ptr<StreamSocket>* accept_socket_; - CompletionCallback accept_callback_; - - BoundNetLog net_log_; -}; - -} // namespace net - -#endif // NET_SOCKET_TCP_SERVER_SOCKET_WIN_H_ diff --git a/chromium/net/socket/tcp_socket.cc b/chromium/net/socket/tcp_socket.cc new file mode 100644 index 00000000000..fd72f6b4640 --- /dev/null +++ b/chromium/net/socket/tcp_socket.cc @@ -0,0 +1,59 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/tcp_socket.h" + +#include "base/file_util.h" +#include "base/files/file_path.h" + +namespace net { + +namespace { + +#if defined(OS_LINUX) + +// Checks to see if the system supports TCP FastOpen. Notably, it requires +// kernel support. Additionally, this checks system configuration to ensure that +// it's enabled. +bool SystemSupportsTCPFastOpen() { + static const base::FilePath::CharType kTCPFastOpenProcFilePath[] = + "/proc/sys/net/ipv4/tcp_fastopen"; + std::string system_enabled_tcp_fastopen; + if (!base::ReadFileToString( + base::FilePath(kTCPFastOpenProcFilePath), + &system_enabled_tcp_fastopen)) { + return false; + } + + // As per http://lxr.linux.no/linux+v3.7.7/include/net/tcp.h#L225 + // TFO_CLIENT_ENABLE is the LSB + if (system_enabled_tcp_fastopen.empty() || + (system_enabled_tcp_fastopen[0] & 0x1) == 0) { + return false; + } + + return true; +} + +#else + +bool SystemSupportsTCPFastOpen() { + return false; +} + +#endif + +bool g_tcp_fastopen_enabled = false; + +} // namespace + +void SetTCPFastOpenEnabled(bool value) { + g_tcp_fastopen_enabled = value && SystemSupportsTCPFastOpen(); +} + +bool IsTCPFastOpenEnabled() { + return g_tcp_fastopen_enabled; +} + +} // namespace net diff --git a/chromium/net/socket/tcp_socket.h b/chromium/net/socket/tcp_socket.h new file mode 100644 index 00000000000..8b36fade758 --- /dev/null +++ b/chromium/net/socket/tcp_socket.h @@ -0,0 +1,40 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_SOCKET_TCP_SOCKET_H_ +#define NET_SOCKET_TCP_SOCKET_H_ + +#include "build/build_config.h" +#include "net/base/net_export.h" + +#if defined(OS_WIN) +#include "net/socket/tcp_socket_win.h" +#elif defined(OS_POSIX) +#include "net/socket/tcp_socket_libevent.h" +#endif + +namespace net { + +// Enable/disable experimental TCP FastOpen option. +// Not thread safe. Must be called during initialization/startup only. +NET_EXPORT void SetTCPFastOpenEnabled(bool value); + +// Check if the TCP FastOpen option is enabled. +bool IsTCPFastOpenEnabled(); + +// TCPSocket provides a platform-independent interface for TCP sockets. +// +// It is recommended to use TCPClientSocket/TCPServerSocket instead of this +// class, unless a clear separation of client and server socket functionality is +// not suitable for your use case (e.g., a socket needs to be created and bound +// before you know whether it is a client or server socket). +#if defined(OS_WIN) +typedef TCPSocketWin TCPSocket; +#elif defined(OS_POSIX) +typedef TCPSocketLibevent TCPSocket; +#endif + +} // namespace net + +#endif // NET_SOCKET_TCP_SOCKET_H_ diff --git a/chromium/net/socket/tcp_client_socket_libevent.cc b/chromium/net/socket/tcp_socket_libevent.cc index 2f7e4b4b255..66416f70207 100644 --- a/chromium/net/socket/tcp_client_socket_libevent.cc +++ b/chromium/net/socket/tcp_socket_libevent.cc @@ -1,29 +1,27 @@ -// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Copyright 2013 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#include "net/socket/tcp_client_socket.h" +#include "net/socket/tcp_socket.h" #include <errno.h> #include <fcntl.h> #include <netdb.h> -#include <sys/socket.h> -#include <netinet/tcp.h> -#if defined(OS_POSIX) #include <netinet/in.h> -#endif +#include <netinet/tcp.h> +#include <sys/socket.h> +#include "base/callback_helpers.h" #include "base/logging.h" -#include "base/message_loop/message_loop.h" #include "base/metrics/histogram.h" #include "base/metrics/stats_counters.h" #include "base/posix/eintr_wrapper.h" -#include "base/strings/string_util.h" +#include "build/build_config.h" +#include "net/base/address_list.h" #include "net/base/connection_type_histograms.h" #include "net/base/io_buffer.h" #include "net/base/ip_endpoint.h" #include "net/base/net_errors.h" -#include "net/base/net_log.h" #include "net/base/net_util.h" #include "net/base/network_change_notifier.h" #include "net/socket/socket_net_log_params.h" @@ -37,7 +35,6 @@ namespace net { namespace { -const int kInvalidSocket = -1; const int kTCPKeepAliveSeconds = 45; // SetTCPNoDelay turns on/off buffering in the kernel. By default, TCP sockets @@ -46,13 +43,12 @@ const int kTCPKeepAliveSeconds = 45; // `man 7 tcp`. bool SetTCPNoDelay(int fd, bool no_delay) { int on = no_delay ? 1 : 0; - int error = setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &on, - sizeof(on)); + int error = setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &on, sizeof(on)); return error == 0; } // SetTCPKeepAlive sets SO_KEEPALIVE. -bool SetTCPKeepAlive(int fd, bool enable, int delay) { +bool SetTCPKeepAlive(int fd, bool enable, int delay) { int on = enable ? 1 : 0; if (setsockopt(fd, SOL_SOCKET, SO_KEEPALIVE, &on, sizeof(on))) { PLOG(ERROR) << "Failed to set SO_KEEPALIVE on fd: " << fd; @@ -73,36 +69,6 @@ bool SetTCPKeepAlive(int fd, bool enable, int delay) { return true; } -// Sets socket parameters. Returns the OS error code (or 0 on -// success). -int SetupSocket(int socket) { - if (SetNonBlocking(socket)) - return errno; - - // This mirrors the behaviour on Windows. See the comment in - // tcp_client_socket_win.cc after searching for "NODELAY". - SetTCPNoDelay(socket, true); // If SetTCPNoDelay fails, we don't care. - SetTCPKeepAlive(socket, true, kTCPKeepAliveSeconds); - - return 0; -} - -// Creates a new socket and sets default parameters for it. Returns -// the OS error code (or 0 on success). -int CreateSocket(int family, int* socket) { - *socket = ::socket(family, SOCK_STREAM, IPPROTO_TCP); - if (*socket == kInvalidSocket) - return errno; - int error = SetupSocket(*socket); - if (error) { - if (HANDLE_EINTR(close(*socket)) < 0) - PLOG(ERROR) << "close"; - *socket = kInvalidSocket; - return error; - } - return 0; -} - int MapConnectError(int os_error) { switch (os_error) { case EACCES: @@ -128,275 +94,206 @@ int MapConnectError(int os_error) { //----------------------------------------------------------------------------- -TCPClientSocketLibevent::TCPClientSocketLibevent( - const AddressList& addresses, - net::NetLog* net_log, - const net::NetLog::Source& source) +TCPSocketLibevent::Watcher::Watcher( + const base::Closure& read_ready_callback, + const base::Closure& write_ready_callback) + : read_ready_callback_(read_ready_callback), + write_ready_callback_(write_ready_callback) { +} + +TCPSocketLibevent::Watcher::~Watcher() { +} + +void TCPSocketLibevent::Watcher::OnFileCanReadWithoutBlocking(int /* fd */) { + if (!read_ready_callback_.is_null()) + read_ready_callback_.Run(); + else + NOTREACHED(); +} + +void TCPSocketLibevent::Watcher::OnFileCanWriteWithoutBlocking(int /* fd */) { + if (!write_ready_callback_.is_null()) + write_ready_callback_.Run(); + else + NOTREACHED(); +} + +TCPSocketLibevent::TCPSocketLibevent(NetLog* net_log, + const NetLog::Source& source) : socket_(kInvalidSocket), - bound_socket_(kInvalidSocket), - addresses_(addresses), - current_address_index_(-1), - read_watcher_(this), - write_watcher_(this), - next_connect_state_(CONNECT_STATE_NONE), - connect_os_error_(0), - net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)), - previously_disconnected_(false), + accept_watcher_(base::Bind(&TCPSocketLibevent::DidCompleteAccept, + base::Unretained(this)), + base::Closure()), + accept_socket_(NULL), + accept_address_(NULL), + read_watcher_(base::Bind(&TCPSocketLibevent::DidCompleteRead, + base::Unretained(this)), + base::Closure()), + write_watcher_(base::Closure(), + base::Bind(&TCPSocketLibevent::DidCompleteConnectOrWrite, + base::Unretained(this))), + read_buf_len_(0), + write_buf_len_(0), use_tcp_fastopen_(IsTCPFastOpenEnabled()), tcp_fastopen_connected_(false), - fast_open_status_(FAST_OPEN_STATUS_UNKNOWN) { + fast_open_status_(FAST_OPEN_STATUS_UNKNOWN), + waiting_connect_(false), + connect_os_error_(0), + logging_multiple_connect_attempts_(false), + net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)) { net_log_.BeginEvent(NetLog::TYPE_SOCKET_ALIVE, source.ToEventParametersCallback()); } -TCPClientSocketLibevent::~TCPClientSocketLibevent() { - Disconnect(); +TCPSocketLibevent::~TCPSocketLibevent() { net_log_.EndEvent(NetLog::TYPE_SOCKET_ALIVE); if (tcp_fastopen_connected_) { UMA_HISTOGRAM_ENUMERATION("Net.TcpFastOpenSocketConnection", fast_open_status_, FAST_OPEN_MAX_VALUE); } + Close(); } -int TCPClientSocketLibevent::AdoptSocket(int socket) { +int TCPSocketLibevent::Open(AddressFamily family) { + DCHECK(CalledOnValidThread()); DCHECK_EQ(socket_, kInvalidSocket); - int error = SetupSocket(socket); - if (error) - return MapSystemError(error); - - socket_ = socket; + socket_ = CreatePlatformSocket(ConvertAddressFamily(family), SOCK_STREAM, + IPPROTO_TCP); + if (socket_ < 0) { + PLOG(ERROR) << "CreatePlatformSocket() returned an error"; + return MapSystemError(errno); + } - // This is to make GetPeerAddress() work. It's up to the caller ensure - // that |address_| contains a reasonable address for this - // socket. (i.e. at least match IPv4 vs IPv6!). - current_address_index_ = 0; - use_history_.set_was_ever_connected(); + if (SetNonBlocking(socket_)) { + int result = MapSystemError(errno); + Close(); + return result; + } return OK; } -int TCPClientSocketLibevent::Bind(const IPEndPoint& address) { - if (current_address_index_ >= 0 || bind_address_.get()) { - // Cannot bind the socket if we are already bound connected or - // connecting. - return ERR_UNEXPECTED; - } - - SockaddrStorage storage; - if (!address.ToSockAddr(storage.addr, &storage.addr_len)) - return ERR_INVALID_ARGUMENT; +int TCPSocketLibevent::AdoptConnectedSocket(int socket, + const IPEndPoint& peer_address) { + DCHECK(CalledOnValidThread()); + DCHECK_EQ(socket_, kInvalidSocket); - // Create |bound_socket_| and try to bind it to |address|. - int error = CreateSocket(address.GetSockAddrFamily(), &bound_socket_); - if (error) - return MapSystemError(error); + socket_ = socket; - if (HANDLE_EINTR(bind(bound_socket_, storage.addr, storage.addr_len))) { - error = errno; - if (HANDLE_EINTR(close(bound_socket_)) < 0) - PLOG(ERROR) << "close"; - bound_socket_ = kInvalidSocket; - return MapSystemError(error); + if (SetNonBlocking(socket_)) { + int result = MapSystemError(errno); + Close(); + return result; } - bind_address_.reset(new IPEndPoint(address)); + peer_address_.reset(new IPEndPoint(peer_address)); - return 0; + return OK; } -int TCPClientSocketLibevent::Connect(const CompletionCallback& callback) { +int TCPSocketLibevent::Bind(const IPEndPoint& address) { DCHECK(CalledOnValidThread()); + DCHECK_NE(socket_, kInvalidSocket); - // If already connected, then just return OK. - if (socket_ != kInvalidSocket) - return OK; - - base::StatsCounter connects("tcp.connect"); - connects.Increment(); - - DCHECK(!waiting_connect()); - - net_log_.BeginEvent(NetLog::TYPE_TCP_CONNECT, - addresses_.CreateNetLogCallback()); - - // We will try to connect to each address in addresses_. Start with the - // first one in the list. - next_connect_state_ = CONNECT_STATE_CONNECT; - current_address_index_ = 0; + SockaddrStorage storage; + if (!address.ToSockAddr(storage.addr, &storage.addr_len)) + return ERR_ADDRESS_INVALID; - int rv = DoConnectLoop(OK); - if (rv == ERR_IO_PENDING) { - // Synchronous operation not supported. - DCHECK(!callback.is_null()); - write_callback_ = callback; - } else { - LogConnectCompletion(rv); + int result = bind(socket_, storage.addr, storage.addr_len); + if (result < 0) { + PLOG(ERROR) << "bind() returned an error"; + return MapSystemError(errno); } - return rv; -} - -int TCPClientSocketLibevent::DoConnectLoop(int result) { - DCHECK_NE(next_connect_state_, CONNECT_STATE_NONE); - - int rv = result; - do { - ConnectState state = next_connect_state_; - next_connect_state_ = CONNECT_STATE_NONE; - switch (state) { - case CONNECT_STATE_CONNECT: - DCHECK_EQ(OK, rv); - rv = DoConnect(); - break; - case CONNECT_STATE_CONNECT_COMPLETE: - rv = DoConnectComplete(rv); - break; - default: - LOG(DFATAL) << "bad state"; - rv = ERR_UNEXPECTED; - break; - } - } while (rv != ERR_IO_PENDING && next_connect_state_ != CONNECT_STATE_NONE); - - return rv; + return OK; } -int TCPClientSocketLibevent::DoConnect() { - DCHECK_GE(current_address_index_, 0); - DCHECK_LT(current_address_index_, static_cast<int>(addresses_.size())); - DCHECK_EQ(0, connect_os_error_); - - const IPEndPoint& endpoint = addresses_[current_address_index_]; +int TCPSocketLibevent::Listen(int backlog) { + DCHECK(CalledOnValidThread()); + DCHECK_GT(backlog, 0); + DCHECK_NE(socket_, kInvalidSocket); - if (previously_disconnected_) { - use_history_.Reset(); - previously_disconnected_ = false; + int result = listen(socket_, backlog); + if (result < 0) { + PLOG(ERROR) << "listen() returned an error"; + return MapSystemError(errno); } - net_log_.BeginEvent(NetLog::TYPE_TCP_CONNECT_ATTEMPT, - CreateNetLogIPEndPointCallback(&endpoint)); + return OK; +} - next_connect_state_ = CONNECT_STATE_CONNECT_COMPLETE; +int TCPSocketLibevent::Accept(scoped_ptr<TCPSocketLibevent>* socket, + IPEndPoint* address, + const CompletionCallback& callback) { + DCHECK(CalledOnValidThread()); + DCHECK(socket); + DCHECK(address); + DCHECK(!callback.is_null()); + DCHECK(accept_callback_.is_null()); - if (bound_socket_ != kInvalidSocket) { - DCHECK(bind_address_.get()); - socket_ = bound_socket_; - bound_socket_ = kInvalidSocket; - } else { - // Create a non-blocking socket. - connect_os_error_ = CreateSocket(endpoint.GetSockAddrFamily(), &socket_); - if (connect_os_error_) - return MapSystemError(connect_os_error_); - - if (bind_address_.get()) { - SockaddrStorage storage; - if (!bind_address_->ToSockAddr(storage.addr, &storage.addr_len)) - return ERR_INVALID_ARGUMENT; - if (HANDLE_EINTR(bind(socket_, storage.addr, storage.addr_len))) - return MapSystemError(errno); - } - } + net_log_.BeginEvent(NetLog::TYPE_TCP_ACCEPT); - // Connect the socket. - if (!use_tcp_fastopen_) { - SockaddrStorage storage; - if (!endpoint.ToSockAddr(storage.addr, &storage.addr_len)) - return ERR_INVALID_ARGUMENT; + int result = AcceptInternal(socket, address); - if (!HANDLE_EINTR(connect(socket_, storage.addr, storage.addr_len))) { - // Connected without waiting! - return OK; + if (result == ERR_IO_PENDING) { + if (!base::MessageLoopForIO::current()->WatchFileDescriptor( + socket_, true, base::MessageLoopForIO::WATCH_READ, + &accept_socket_watcher_, &accept_watcher_)) { + PLOG(ERROR) << "WatchFileDescriptor failed on read"; + return MapSystemError(errno); } - } else { - // With TCP FastOpen, we pretend that the socket is connected. - DCHECK(!tcp_fastopen_connected_); - return OK; - } - - // Check if the connect() failed synchronously. - connect_os_error_ = errno; - if (connect_os_error_ != EINPROGRESS) - return MapConnectError(connect_os_error_); - // Otherwise the connect() is going to complete asynchronously, so watch - // for its completion. - if (!base::MessageLoopForIO::current()->WatchFileDescriptor( - socket_, true, base::MessageLoopForIO::WATCH_WRITE, - &write_socket_watcher_, &write_watcher_)) { - connect_os_error_ = errno; - DVLOG(1) << "WatchFileDescriptor failed: " << connect_os_error_; - return MapSystemError(connect_os_error_); + accept_socket_ = socket; + accept_address_ = address; + accept_callback_ = callback; } - return ERR_IO_PENDING; -} - -int TCPClientSocketLibevent::DoConnectComplete(int result) { - // Log the end of this attempt (and any OS error it threw). - int os_error = connect_os_error_; - connect_os_error_ = 0; - if (result != OK) { - net_log_.EndEvent(NetLog::TYPE_TCP_CONNECT_ATTEMPT, - NetLog::IntegerCallback("os_error", os_error)); - } else { - net_log_.EndEvent(NetLog::TYPE_TCP_CONNECT_ATTEMPT); - } - - if (result == OK) { - write_socket_watcher_.StopWatchingFileDescriptor(); - use_history_.set_was_ever_connected(); - return OK; // Done! - } - - // Close whatever partially connected socket we currently have. - DoDisconnect(); - - // Try to fall back to the next address in the list. - if (current_address_index_ + 1 < static_cast<int>(addresses_.size())) { - next_connect_state_ = CONNECT_STATE_CONNECT; - ++current_address_index_; - return OK; - } - - // Otherwise there is nothing to fall back to, so give up. return result; } -void TCPClientSocketLibevent::Disconnect() { +int TCPSocketLibevent::Connect(const IPEndPoint& address, + const CompletionCallback& callback) { DCHECK(CalledOnValidThread()); + DCHECK_NE(socket_, kInvalidSocket); + DCHECK(!waiting_connect_); - DoDisconnect(); - current_address_index_ = -1; - bind_address_.reset(); -} + // |peer_address_| will be non-NULL if Connect() has been called. Unless + // Close() is called to reset the internal state, a second call to Connect() + // is not allowed. + // Please note that we don't allow a second Connect() even if the previous + // Connect() has failed. Connecting the same |socket_| again after a + // connection attempt failed results in unspecified behavior according to + // POSIX. + DCHECK(!peer_address_); -void TCPClientSocketLibevent::DoDisconnect() { - if (socket_ == kInvalidSocket) - return; + if (!logging_multiple_connect_attempts_) + LogConnectBegin(AddressList(address)); - bool ok = read_socket_watcher_.StopWatchingFileDescriptor(); - DCHECK(ok); - ok = write_socket_watcher_.StopWatchingFileDescriptor(); - DCHECK(ok); - if (HANDLE_EINTR(close(socket_)) < 0) - PLOG(ERROR) << "close"; - socket_ = kInvalidSocket; - previously_disconnected_ = true; + peer_address_.reset(new IPEndPoint(address)); + + int rv = DoConnect(); + if (rv == ERR_IO_PENDING) { + // Synchronous operation not supported. + DCHECK(!callback.is_null()); + write_callback_ = callback; + waiting_connect_ = true; + } else { + DoConnectComplete(rv); + } + + return rv; } -bool TCPClientSocketLibevent::IsConnected() const { +bool TCPSocketLibevent::IsConnected() const { DCHECK(CalledOnValidThread()); - if (socket_ == kInvalidSocket || waiting_connect()) + if (socket_ == kInvalidSocket || waiting_connect_) return false; - if (use_tcp_fastopen_ && !tcp_fastopen_connected_) { + if (use_tcp_fastopen_ && !tcp_fastopen_connected_ && peer_address_) { // With TCP FastOpen, we pretend that the socket is connected. - // This allows GetPeerAddress() to return current_ai_ as the peer - // address. Since we don't fail over to the next address if - // sendto() fails, current_ai_ is the only possible peer address. - CHECK_LT(current_address_index_, static_cast<int>(addresses_.size())); + // This allows GetPeerAddress() to return peer_address_. return true; } @@ -411,10 +308,10 @@ bool TCPClientSocketLibevent::IsConnected() const { return true; } -bool TCPClientSocketLibevent::IsConnectedAndIdle() const { +bool TCPSocketLibevent::IsConnectedAndIdle() const { DCHECK(CalledOnValidThread()); - if (socket_ == kInvalidSocket || waiting_connect()) + if (socket_ == kInvalidSocket || waiting_connect_) return false; // TODO(wtc): should we also handle the TCP FastOpen case here, @@ -432,12 +329,12 @@ bool TCPClientSocketLibevent::IsConnectedAndIdle() const { return true; } -int TCPClientSocketLibevent::Read(IOBuffer* buf, - int buf_len, - const CompletionCallback& callback) { +int TCPSocketLibevent::Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) { DCHECK(CalledOnValidThread()); DCHECK_NE(kInvalidSocket, socket_); - DCHECK(!waiting_connect()); + DCHECK(!waiting_connect_); DCHECK(read_callback_.is_null()); // Synchronous operation not supported DCHECK(!callback.is_null()); @@ -447,8 +344,6 @@ int TCPClientSocketLibevent::Read(IOBuffer* buf, if (nread >= 0) { base::StatsCounter read_bytes("tcp.read_bytes"); read_bytes.Add(nread); - if (nread > 0) - use_history_.set_was_used_to_convey_data(); net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_RECEIVED, nread, buf->data()); RecordFastOpenStatus(); @@ -474,12 +369,12 @@ int TCPClientSocketLibevent::Read(IOBuffer* buf, return ERR_IO_PENDING; } -int TCPClientSocketLibevent::Write(IOBuffer* buf, - int buf_len, - const CompletionCallback& callback) { +int TCPSocketLibevent::Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) { DCHECK(CalledOnValidThread()); DCHECK_NE(kInvalidSocket, socket_); - DCHECK(!waiting_connect()); + DCHECK(!waiting_connect_); DCHECK(write_callback_.is_null()); // Synchronous operation not supported DCHECK(!callback.is_null()); @@ -489,8 +384,6 @@ int TCPClientSocketLibevent::Write(IOBuffer* buf, if (nwrite >= 0) { base::StatsCounter write_bytes("tcp.write_bytes"); write_bytes.Add(nwrite); - if (nwrite > 0) - use_history_.set_was_used_to_convey_data(); net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_SENT, nwrite, buf->data()); return nwrite; @@ -515,56 +408,67 @@ int TCPClientSocketLibevent::Write(IOBuffer* buf, return ERR_IO_PENDING; } -int TCPClientSocketLibevent::InternalWrite(IOBuffer* buf, int buf_len) { - int nwrite; - if (use_tcp_fastopen_ && !tcp_fastopen_connected_) { - SockaddrStorage storage; - if (!addresses_[current_address_index_].ToSockAddr(storage.addr, - &storage.addr_len)) { - errno = EINVAL; - return -1; - } +int TCPSocketLibevent::GetLocalAddress(IPEndPoint* address) const { + DCHECK(CalledOnValidThread()); + DCHECK(address); - int flags = 0x20000000; // Magic flag to enable TCP_FASTOPEN. -#if defined(OS_LINUX) - // sendto() will fail with EPIPE when the system doesn't support TCP Fast - // Open. Theoretically that shouldn't happen since the caller should check - // for system support on startup, but users may dynamically disable TCP Fast - // Open via sysctl. - flags |= MSG_NOSIGNAL; -#endif // defined(OS_LINUX) - nwrite = HANDLE_EINTR(sendto(socket_, - buf->data(), - buf_len, - flags, - storage.addr, - storage.addr_len)); - tcp_fastopen_connected_ = true; + SockaddrStorage storage; + if (getsockname(socket_, storage.addr, &storage.addr_len) < 0) + return MapSystemError(errno); + if (!address->FromSockAddr(storage.addr, storage.addr_len)) + return ERR_ADDRESS_INVALID; - if (nwrite < 0) { - DCHECK_NE(EPIPE, errno); + return OK; +} - // If errno == EINPROGRESS, that means the kernel didn't have a cookie - // and would block. The kernel is internally doing a connect() though. - // Remap EINPROGRESS to EAGAIN so we treat this the same as our other - // asynchronous cases. Note that the user buffer has not been copied to - // kernel space. - if (errno == EINPROGRESS) { - errno = EAGAIN; - fast_open_status_ = FAST_OPEN_SLOW_CONNECT_RETURN; - } else { - fast_open_status_ = FAST_OPEN_ERROR; - } - } else { - fast_open_status_ = FAST_OPEN_FAST_CONNECT_RETURN; - } - } else { - nwrite = HANDLE_EINTR(write(socket_, buf->data(), buf_len)); - } - return nwrite; +int TCPSocketLibevent::GetPeerAddress(IPEndPoint* address) const { + DCHECK(CalledOnValidThread()); + DCHECK(address); + if (!IsConnected()) + return ERR_SOCKET_NOT_CONNECTED; + *address = *peer_address_; + return OK; } -bool TCPClientSocketLibevent::SetReceiveBufferSize(int32 size) { +int TCPSocketLibevent::SetDefaultOptionsForServer() { + DCHECK(CalledOnValidThread()); + return SetAddressReuse(true); +} + +void TCPSocketLibevent::SetDefaultOptionsForClient() { + DCHECK(CalledOnValidThread()); + + // This mirrors the behaviour on Windows. See the comment in + // tcp_socket_win.cc after searching for "NODELAY". + SetTCPNoDelay(socket_, true); // If SetTCPNoDelay fails, we don't care. + SetTCPKeepAlive(socket_, true, kTCPKeepAliveSeconds); +} + +int TCPSocketLibevent::SetAddressReuse(bool allow) { + DCHECK(CalledOnValidThread()); + + // SO_REUSEADDR is useful for server sockets to bind to a recently unbound + // port. When a socket is closed, the end point changes its state to TIME_WAIT + // and wait for 2 MSL (maximum segment lifetime) to ensure the remote peer + // acknowledges its closure. For server sockets, it is usually safe to + // bind to a TIME_WAIT end point immediately, which is a widely adopted + // behavior. + // + // Note that on *nix, SO_REUSEADDR does not enable the TCP socket to bind to + // an end point that is already bound by another socket. To do that one must + // set SO_REUSEPORT instead. This option is not provided on Linux prior + // to 3.9. + // + // SO_REUSEPORT is provided in MacOS X and iOS. + int boolean_value = allow ? 1 : 0; + int rv = setsockopt(socket_, SOL_SOCKET, SO_REUSEADDR, &boolean_value, + sizeof(boolean_value)); + if (rv < 0) + return MapSystemError(errno); + return OK; +} + +bool TCPSocketLibevent::SetReceiveBufferSize(int32 size) { DCHECK(CalledOnValidThread()); int rv = setsockopt(socket_, SOL_SOCKET, SO_RCVBUF, reinterpret_cast<const char*>(&size), @@ -573,7 +477,7 @@ bool TCPClientSocketLibevent::SetReceiveBufferSize(int32 size) { return rv == 0; } -bool TCPClientSocketLibevent::SetSendBufferSize(int32 size) { +bool TCPSocketLibevent::SetSendBufferSize(int32 size) { DCHECK(CalledOnValidThread()); int rv = setsockopt(socket_, SOL_SOCKET, SO_SNDBUF, reinterpret_cast<const char*>(&size), @@ -582,31 +486,180 @@ bool TCPClientSocketLibevent::SetSendBufferSize(int32 size) { return rv == 0; } -bool TCPClientSocketLibevent::SetKeepAlive(bool enable, int delay) { - int socket = socket_ != kInvalidSocket ? socket_ : bound_socket_; - return SetTCPKeepAlive(socket, enable, delay); +bool TCPSocketLibevent::SetKeepAlive(bool enable, int delay) { + DCHECK(CalledOnValidThread()); + return SetTCPKeepAlive(socket_, enable, delay); } -bool TCPClientSocketLibevent::SetNoDelay(bool no_delay) { - int socket = socket_ != kInvalidSocket ? socket_ : bound_socket_; - return SetTCPNoDelay(socket, no_delay); +bool TCPSocketLibevent::SetNoDelay(bool no_delay) { + DCHECK(CalledOnValidThread()); + return SetTCPNoDelay(socket_, no_delay); } -void TCPClientSocketLibevent::ReadWatcher::OnFileCanReadWithoutBlocking(int) { - socket_->RecordFastOpenStatus(); - if (!socket_->read_callback_.is_null()) - socket_->DidCompleteRead(); +void TCPSocketLibevent::Close() { + DCHECK(CalledOnValidThread()); + + bool ok = accept_socket_watcher_.StopWatchingFileDescriptor(); + DCHECK(ok); + ok = read_socket_watcher_.StopWatchingFileDescriptor(); + DCHECK(ok); + ok = write_socket_watcher_.StopWatchingFileDescriptor(); + DCHECK(ok); + + if (socket_ != kInvalidSocket) { + if (HANDLE_EINTR(close(socket_)) < 0) + PLOG(ERROR) << "close"; + socket_ = kInvalidSocket; + } + + if (!accept_callback_.is_null()) { + accept_socket_ = NULL; + accept_address_ = NULL; + accept_callback_.Reset(); + } + + if (!read_callback_.is_null()) { + read_buf_ = NULL; + read_buf_len_ = 0; + read_callback_.Reset(); + } + + if (!write_callback_.is_null()) { + write_buf_ = NULL; + write_buf_len_ = 0; + write_callback_.Reset(); + } + + tcp_fastopen_connected_ = false; + fast_open_status_ = FAST_OPEN_STATUS_UNKNOWN; + waiting_connect_ = false; + peer_address_.reset(); + connect_os_error_ = 0; +} + +bool TCPSocketLibevent::UsingTCPFastOpen() const { + return use_tcp_fastopen_; +} + +void TCPSocketLibevent::StartLoggingMultipleConnectAttempts( + const AddressList& addresses) { + if (!logging_multiple_connect_attempts_) { + logging_multiple_connect_attempts_ = true; + LogConnectBegin(addresses); + } else { + NOTREACHED(); + } +} + +void TCPSocketLibevent::EndLoggingMultipleConnectAttempts(int net_error) { + if (logging_multiple_connect_attempts_) { + LogConnectEnd(net_error); + logging_multiple_connect_attempts_ = false; + } else { + NOTREACHED(); + } +} + +int TCPSocketLibevent::AcceptInternal(scoped_ptr<TCPSocketLibevent>* socket, + IPEndPoint* address) { + SockaddrStorage storage; + int new_socket = HANDLE_EINTR(accept(socket_, + storage.addr, + &storage.addr_len)); + if (new_socket < 0) { + int net_error = MapSystemError(errno); + if (net_error != ERR_IO_PENDING) + net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_ACCEPT, net_error); + return net_error; + } + + IPEndPoint ip_end_point; + if (!ip_end_point.FromSockAddr(storage.addr, storage.addr_len)) { + NOTREACHED(); + if (HANDLE_EINTR(close(new_socket)) < 0) + PLOG(ERROR) << "close"; + net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_ACCEPT, + ERR_ADDRESS_INVALID); + return ERR_ADDRESS_INVALID; + } + scoped_ptr<TCPSocketLibevent> tcp_socket(new TCPSocketLibevent( + net_log_.net_log(), net_log_.source())); + int adopt_result = tcp_socket->AdoptConnectedSocket(new_socket, ip_end_point); + if (adopt_result != OK) { + net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_ACCEPT, adopt_result); + return adopt_result; + } + *socket = tcp_socket.Pass(); + *address = ip_end_point; + net_log_.EndEvent(NetLog::TYPE_TCP_ACCEPT, + CreateNetLogIPEndPointCallback(&ip_end_point)); + return OK; +} + +int TCPSocketLibevent::DoConnect() { + DCHECK_EQ(0, connect_os_error_); + + net_log_.BeginEvent(NetLog::TYPE_TCP_CONNECT_ATTEMPT, + CreateNetLogIPEndPointCallback(peer_address_.get())); + + // Connect the socket. + if (!use_tcp_fastopen_) { + SockaddrStorage storage; + if (!peer_address_->ToSockAddr(storage.addr, &storage.addr_len)) + return ERR_INVALID_ARGUMENT; + + if (!HANDLE_EINTR(connect(socket_, storage.addr, storage.addr_len))) { + // Connected without waiting! + return OK; + } + } else { + // With TCP FastOpen, we pretend that the socket is connected. + DCHECK(!tcp_fastopen_connected_); + return OK; + } + + // Check if the connect() failed synchronously. + connect_os_error_ = errno; + if (connect_os_error_ != EINPROGRESS) + return MapConnectError(connect_os_error_); + + // Otherwise the connect() is going to complete asynchronously, so watch + // for its completion. + if (!base::MessageLoopForIO::current()->WatchFileDescriptor( + socket_, true, base::MessageLoopForIO::WATCH_WRITE, + &write_socket_watcher_, &write_watcher_)) { + connect_os_error_ = errno; + DVLOG(1) << "WatchFileDescriptor failed: " << connect_os_error_; + return MapSystemError(connect_os_error_); + } + + return ERR_IO_PENDING; } -void TCPClientSocketLibevent::WriteWatcher::OnFileCanWriteWithoutBlocking(int) { - if (socket_->waiting_connect()) { - socket_->DidCompleteConnect(); - } else if (!socket_->write_callback_.is_null()) { - socket_->DidCompleteWrite(); +void TCPSocketLibevent::DoConnectComplete(int result) { + // Log the end of this attempt (and any OS error it threw). + int os_error = connect_os_error_; + connect_os_error_ = 0; + if (result != OK) { + net_log_.EndEvent(NetLog::TYPE_TCP_CONNECT_ATTEMPT, + NetLog::IntegerCallback("os_error", os_error)); + } else { + net_log_.EndEvent(NetLog::TYPE_TCP_CONNECT_ATTEMPT); } + + if (!logging_multiple_connect_attempts_) + LogConnectEnd(result); } -void TCPClientSocketLibevent::LogConnectCompletion(int net_error) { +void TCPSocketLibevent::LogConnectBegin(const AddressList& addresses) { + base::StatsCounter connects("tcp.connect"); + connects.Increment(); + + net_log_.BeginEvent(NetLog::TYPE_TCP_CONNECT, + addresses.CreateNetLogCallback()); +} + +void TCPSocketLibevent::LogConnectEnd(int net_error) { if (net_error == OK) UpdateConnectionTypeHistograms(CONNECTION_ANY); @@ -629,50 +682,11 @@ void TCPClientSocketLibevent::LogConnectCompletion(int net_error) { storage.addr_len)); } -void TCPClientSocketLibevent::DoReadCallback(int rv) { - DCHECK_NE(rv, ERR_IO_PENDING); - DCHECK(!read_callback_.is_null()); - - // since Run may result in Read being called, clear read_callback_ up front. - CompletionCallback c = read_callback_; - read_callback_.Reset(); - c.Run(rv); -} - -void TCPClientSocketLibevent::DoWriteCallback(int rv) { - DCHECK_NE(rv, ERR_IO_PENDING); - DCHECK(!write_callback_.is_null()); - - // since Run may result in Write being called, clear write_callback_ up front. - CompletionCallback c = write_callback_; - write_callback_.Reset(); - c.Run(rv); -} - -void TCPClientSocketLibevent::DidCompleteConnect() { - DCHECK_EQ(next_connect_state_, CONNECT_STATE_CONNECT_COMPLETE); - - // Get the error that connect() completed with. - int os_error = 0; - socklen_t len = sizeof(os_error); - if (getsockopt(socket_, SOL_SOCKET, SO_ERROR, &os_error, &len) < 0) - os_error = errno; - - // TODO(eroman): Is this check really necessary? - if (os_error == EINPROGRESS || os_error == EALREADY) { - NOTREACHED(); // This indicates a bug in libevent or our code. +void TCPSocketLibevent::DidCompleteRead() { + RecordFastOpenStatus(); + if (read_callback_.is_null()) return; - } - connect_os_error_ = os_error; - int rv = DoConnectLoop(MapConnectError(os_error)); - if (rv != ERR_IO_PENDING) { - LogConnectCompletion(rv); - DoWriteCallback(rv); - } -} - -void TCPClientSocketLibevent::DidCompleteRead() { int bytes_transferred; bytes_transferred = HANDLE_EINTR(read(socket_, read_buf_->data(), read_buf_len_)); @@ -682,8 +696,6 @@ void TCPClientSocketLibevent::DidCompleteRead() { result = bytes_transferred; base::StatsCounter read_bytes("tcp.read_bytes"); read_bytes.Add(bytes_transferred); - if (bytes_transferred > 0) - use_history_.set_was_used_to_convey_data(); net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_RECEIVED, result, read_buf_->data()); } else { @@ -699,11 +711,14 @@ void TCPClientSocketLibevent::DidCompleteRead() { read_buf_len_ = 0; bool ok = read_socket_watcher_.StopWatchingFileDescriptor(); DCHECK(ok); - DoReadCallback(result); + base::ResetAndReturn(&read_callback_).Run(result); } } -void TCPClientSocketLibevent::DidCompleteWrite() { +void TCPSocketLibevent::DidCompleteWrite() { + if (write_callback_.is_null()) + return; + int bytes_transferred; bytes_transferred = HANDLE_EINTR(write(socket_, write_buf_->data(), write_buf_len_)); @@ -713,8 +728,6 @@ void TCPClientSocketLibevent::DidCompleteWrite() { result = bytes_transferred; base::StatsCounter write_bytes("tcp.write_bytes"); write_bytes.Add(bytes_transferred); - if (bytes_transferred > 0) - use_history_.set_was_used_to_convey_data(); net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_SENT, result, write_buf_->data()); } else { @@ -729,40 +742,100 @@ void TCPClientSocketLibevent::DidCompleteWrite() { write_buf_ = NULL; write_buf_len_ = 0; write_socket_watcher_.StopWatchingFileDescriptor(); - DoWriteCallback(result); + base::ResetAndReturn(&write_callback_).Run(result); } } -int TCPClientSocketLibevent::GetPeerAddress(IPEndPoint* address) const { - DCHECK(CalledOnValidThread()); - DCHECK(address); - if (!IsConnected()) - return ERR_SOCKET_NOT_CONNECTED; - *address = addresses_[current_address_index_]; - return OK; +void TCPSocketLibevent::DidCompleteConnect() { + DCHECK(waiting_connect_); + + // Get the error that connect() completed with. + int os_error = 0; + socklen_t len = sizeof(os_error); + if (getsockopt(socket_, SOL_SOCKET, SO_ERROR, &os_error, &len) < 0) + os_error = errno; + + int result = MapConnectError(os_error); + connect_os_error_ = os_error; + if (result != ERR_IO_PENDING) { + DoConnectComplete(result); + waiting_connect_ = false; + write_socket_watcher_.StopWatchingFileDescriptor(); + base::ResetAndReturn(&write_callback_).Run(result); + } +} + +void TCPSocketLibevent::DidCompleteConnectOrWrite() { + if (waiting_connect_) + DidCompleteConnect(); + else + DidCompleteWrite(); } -int TCPClientSocketLibevent::GetLocalAddress(IPEndPoint* address) const { +void TCPSocketLibevent::DidCompleteAccept() { DCHECK(CalledOnValidThread()); - DCHECK(address); - if (socket_ == kInvalidSocket) { - if (bind_address_.get()) { - *address = *bind_address_; - return OK; - } - return ERR_SOCKET_NOT_CONNECTED; + + int result = AcceptInternal(accept_socket_, accept_address_); + if (result != ERR_IO_PENDING) { + accept_socket_ = NULL; + accept_address_ = NULL; + bool ok = accept_socket_watcher_.StopWatchingFileDescriptor(); + DCHECK(ok); + CompletionCallback callback = accept_callback_; + accept_callback_.Reset(); + callback.Run(result); } +} - SockaddrStorage storage; - if (getsockname(socket_, storage.addr, &storage.addr_len)) - return MapSystemError(errno); - if (!address->FromSockAddr(storage.addr, storage.addr_len)) - return ERR_FAILED; +int TCPSocketLibevent::InternalWrite(IOBuffer* buf, int buf_len) { + int nwrite; + if (use_tcp_fastopen_ && !tcp_fastopen_connected_) { + SockaddrStorage storage; + if (!peer_address_->ToSockAddr(storage.addr, &storage.addr_len)) { + errno = EINVAL; + return -1; + } - return OK; + int flags = 0x20000000; // Magic flag to enable TCP_FASTOPEN. +#if defined(OS_LINUX) + // sendto() will fail with EPIPE when the system doesn't support TCP Fast + // Open. Theoretically that shouldn't happen since the caller should check + // for system support on startup, but users may dynamically disable TCP Fast + // Open via sysctl. + flags |= MSG_NOSIGNAL; +#endif // defined(OS_LINUX) + nwrite = HANDLE_EINTR(sendto(socket_, + buf->data(), + buf_len, + flags, + storage.addr, + storage.addr_len)); + tcp_fastopen_connected_ = true; + + if (nwrite < 0) { + DCHECK_NE(EPIPE, errno); + + // If errno == EINPROGRESS, that means the kernel didn't have a cookie + // and would block. The kernel is internally doing a connect() though. + // Remap EINPROGRESS to EAGAIN so we treat this the same as our other + // asynchronous cases. Note that the user buffer has not been copied to + // kernel space. + if (errno == EINPROGRESS) { + errno = EAGAIN; + fast_open_status_ = FAST_OPEN_SLOW_CONNECT_RETURN; + } else { + fast_open_status_ = FAST_OPEN_ERROR; + } + } else { + fast_open_status_ = FAST_OPEN_FAST_CONNECT_RETURN; + } + } else { + nwrite = HANDLE_EINTR(write(socket_, buf->data(), buf_len)); + } + return nwrite; } -void TCPClientSocketLibevent::RecordFastOpenStatus() { +void TCPSocketLibevent::RecordFastOpenStatus() { if (use_tcp_fastopen_ && (fast_open_status_ == FAST_OPEN_FAST_CONNECT_RETURN || fast_open_status_ == FAST_OPEN_SLOW_CONNECT_RETURN)) { @@ -795,36 +868,4 @@ void TCPClientSocketLibevent::RecordFastOpenStatus() { } } -const BoundNetLog& TCPClientSocketLibevent::NetLog() const { - return net_log_; -} - -void TCPClientSocketLibevent::SetSubresourceSpeculation() { - use_history_.set_subresource_speculation(); -} - -void TCPClientSocketLibevent::SetOmniboxSpeculation() { - use_history_.set_omnibox_speculation(); -} - -bool TCPClientSocketLibevent::WasEverUsed() const { - return use_history_.was_used_to_convey_data(); -} - -bool TCPClientSocketLibevent::UsingTCPFastOpen() const { - return use_tcp_fastopen_; -} - -bool TCPClientSocketLibevent::WasNpnNegotiated() const { - return false; -} - -NextProto TCPClientSocketLibevent::GetNegotiatedProtocol() const { - return kProtoUnknown; -} - -bool TCPClientSocketLibevent::GetSSLInfo(SSLInfo* ssl_info) { - return false; -} - } // namespace net diff --git a/chromium/net/socket/tcp_socket_libevent.h b/chromium/net/socket/tcp_socket_libevent.h new file mode 100644 index 00000000000..a50caf0ad59 --- /dev/null +++ b/chromium/net/socket/tcp_socket_libevent.h @@ -0,0 +1,235 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_SOCKET_TCP_SOCKET_LIBEVENT_H_ +#define NET_SOCKET_TCP_SOCKET_LIBEVENT_H_ + +#include "base/basictypes.h" +#include "base/callback.h" +#include "base/compiler_specific.h" +#include "base/memory/ref_counted.h" +#include "base/memory/scoped_ptr.h" +#include "base/message_loop/message_loop.h" +#include "base/threading/non_thread_safe.h" +#include "net/base/address_family.h" +#include "net/base/completion_callback.h" +#include "net/base/net_export.h" +#include "net/base/net_log.h" +#include "net/socket/socket_descriptor.h" + +namespace net { + +class AddressList; +class IOBuffer; +class IPEndPoint; + +class NET_EXPORT TCPSocketLibevent : public base::NonThreadSafe { + public: + TCPSocketLibevent(NetLog* net_log, const NetLog::Source& source); + virtual ~TCPSocketLibevent(); + + int Open(AddressFamily family); + // Takes ownership of |socket|. + int AdoptConnectedSocket(int socket, const IPEndPoint& peer_address); + + int Bind(const IPEndPoint& address); + + int Listen(int backlog); + int Accept(scoped_ptr<TCPSocketLibevent>* socket, + IPEndPoint* address, + const CompletionCallback& callback); + + int Connect(const IPEndPoint& address, const CompletionCallback& callback); + bool IsConnected() const; + bool IsConnectedAndIdle() const; + + // Multiple outstanding requests are not supported. + // Full duplex mode (reading and writing at the same time) is supported. + int Read(IOBuffer* buf, int buf_len, const CompletionCallback& callback); + int Write(IOBuffer* buf, int buf_len, const CompletionCallback& callback); + + int GetLocalAddress(IPEndPoint* address) const; + int GetPeerAddress(IPEndPoint* address) const; + + // Sets various socket options. + // The commonly used options for server listening sockets: + // - SetAddressReuse(true). + int SetDefaultOptionsForServer(); + // The commonly used options for client sockets and accepted sockets: + // - SetNoDelay(true); + // - SetKeepAlive(true, 45). + void SetDefaultOptionsForClient(); + int SetAddressReuse(bool allow); + bool SetReceiveBufferSize(int32 size); + bool SetSendBufferSize(int32 size); + bool SetKeepAlive(bool enable, int delay); + bool SetNoDelay(bool no_delay); + + void Close(); + + bool UsingTCPFastOpen() const; + bool IsValid() const { return socket_ != kInvalidSocket; } + + // Marks the start/end of a series of connect attempts for logging purpose. + // + // TCPClientSocket may attempt to connect to multiple addresses until it + // succeeds in establishing a connection. The corresponding log will have + // multiple NetLog::TYPE_TCP_CONNECT_ATTEMPT entries nested within a + // NetLog::TYPE_TCP_CONNECT. These methods set the start/end of + // NetLog::TYPE_TCP_CONNECT. + // + // TODO(yzshen): Change logging format and let TCPClientSocket log the + // start/end of a series of connect attempts itself. + void StartLoggingMultipleConnectAttempts(const AddressList& addresses); + void EndLoggingMultipleConnectAttempts(int net_error); + + const BoundNetLog& net_log() const { return net_log_; } + + private: + // States that a fast open socket attempt can result in. + enum FastOpenStatus { + FAST_OPEN_STATUS_UNKNOWN, + + // The initial fast open connect attempted returned synchronously, + // indicating that we had and sent a cookie along with the initial data. + FAST_OPEN_FAST_CONNECT_RETURN, + + // The initial fast open connect attempted returned asynchronously, + // indicating that we did not have a cookie for the server. + FAST_OPEN_SLOW_CONNECT_RETURN, + + // Some other error occurred on connection, so we couldn't tell if + // fast open would have worked. + FAST_OPEN_ERROR, + + // An attempt to do a fast open succeeded immediately + // (FAST_OPEN_FAST_CONNECT_RETURN) and we later confirmed that the server + // had acked the data we sent. + FAST_OPEN_SYN_DATA_ACK, + + // An attempt to do a fast open succeeded immediately + // (FAST_OPEN_FAST_CONNECT_RETURN) and we later confirmed that the server + // had nacked the data we sent. + FAST_OPEN_SYN_DATA_NACK, + + // An attempt to do a fast open succeeded immediately + // (FAST_OPEN_FAST_CONNECT_RETURN) and our probe to determine if the + // socket was using fast open failed. + FAST_OPEN_SYN_DATA_FAILED, + + // An attempt to do a fast open failed (FAST_OPEN_SLOW_CONNECT_RETURN) + // and we later confirmed that the server had acked initial data. This + // should never happen (we didn't send data, so it shouldn't have + // been acked). + FAST_OPEN_NO_SYN_DATA_ACK, + + // An attempt to do a fast open failed (FAST_OPEN_SLOW_CONNECT_RETURN) + // and we later discovered that the server had nacked initial data. This + // is the expected case results for FAST_OPEN_SLOW_CONNECT_RETURN. + FAST_OPEN_NO_SYN_DATA_NACK, + + // An attempt to do a fast open failed (FAST_OPEN_SLOW_CONNECT_RETURN) + // and our later probe for ack/nack state failed. + FAST_OPEN_NO_SYN_DATA_FAILED, + + FAST_OPEN_MAX_VALUE + }; + + // Watcher simply forwards notifications to Closure objects set via the + // constructor. + class Watcher: public base::MessageLoopForIO::Watcher { + public: + Watcher(const base::Closure& read_ready_callback, + const base::Closure& write_ready_callback); + virtual ~Watcher(); + + // base::MessageLoopForIO::Watcher methods. + virtual void OnFileCanReadWithoutBlocking(int fd) OVERRIDE; + virtual void OnFileCanWriteWithoutBlocking(int fd) OVERRIDE; + + private: + base::Closure read_ready_callback_; + base::Closure write_ready_callback_; + + DISALLOW_COPY_AND_ASSIGN(Watcher); + }; + + int AcceptInternal(scoped_ptr<TCPSocketLibevent>* socket, + IPEndPoint* address); + + int DoConnect(); + void DoConnectComplete(int result); + + void LogConnectBegin(const AddressList& addresses); + void LogConnectEnd(int net_error); + + void DidCompleteRead(); + void DidCompleteWrite(); + void DidCompleteConnect(); + void DidCompleteConnectOrWrite(); + void DidCompleteAccept(); + + // Internal function to write to a socket. Returns an OS error. + int InternalWrite(IOBuffer* buf, int buf_len); + + // Called when the socket is known to be in a connected state. + void RecordFastOpenStatus(); + + int socket_; + + base::MessageLoopForIO::FileDescriptorWatcher accept_socket_watcher_; + Watcher accept_watcher_; + + scoped_ptr<TCPSocketLibevent>* accept_socket_; + IPEndPoint* accept_address_; + CompletionCallback accept_callback_; + + // The socket's libevent wrappers for reads and writes. + base::MessageLoopForIO::FileDescriptorWatcher read_socket_watcher_; + base::MessageLoopForIO::FileDescriptorWatcher write_socket_watcher_; + + // The corresponding watchers for reads and writes. + Watcher read_watcher_; + Watcher write_watcher_; + + // The buffer used for reads. + scoped_refptr<IOBuffer> read_buf_; + int read_buf_len_; + + // The buffer used for writes. + scoped_refptr<IOBuffer> write_buf_; + int write_buf_len_; + + // External callback; called when read is complete. + CompletionCallback read_callback_; + + // External callback; called when write or connect is complete. + CompletionCallback write_callback_; + + // Enables experimental TCP FastOpen option. + const bool use_tcp_fastopen_; + + // True when TCP FastOpen is in use and we have done the connect. + bool tcp_fastopen_connected_; + + FastOpenStatus fast_open_status_; + + // A connect operation is pending. In this case, |write_callback_| needs to be + // called when connect is complete. + bool waiting_connect_; + + scoped_ptr<IPEndPoint> peer_address_; + // The OS error that a connect attempt last completed with. + int connect_os_error_; + + bool logging_multiple_connect_attempts_; + + BoundNetLog net_log_; + + DISALLOW_COPY_AND_ASSIGN(TCPSocketLibevent); +}; + +} // namespace net + +#endif // NET_SOCKET_TCP_SOCKET_LIBEVENT_H_ diff --git a/chromium/net/socket/tcp_socket_unittest.cc b/chromium/net/socket/tcp_socket_unittest.cc new file mode 100644 index 00000000000..a45fcba016b --- /dev/null +++ b/chromium/net/socket/tcp_socket_unittest.cc @@ -0,0 +1,263 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/tcp_socket.h" + +#include <string.h> + +#include <string> +#include <vector> + +#include "base/memory/ref_counted.h" +#include "base/memory/scoped_ptr.h" +#include "net/base/address_list.h" +#include "net/base/io_buffer.h" +#include "net/base/ip_endpoint.h" +#include "net/base/net_errors.h" +#include "net/base/test_completion_callback.h" +#include "net/socket/tcp_client_socket.h" +#include "testing/gtest/include/gtest/gtest.h" +#include "testing/platform_test.h" + +namespace net { + +namespace { +const int kListenBacklog = 5; + +class TCPSocketTest : public PlatformTest { + protected: + TCPSocketTest() : socket_(NULL, NetLog::Source()) { + } + + void SetUpListenIPv4() { + IPEndPoint address; + ParseAddress("127.0.0.1", 0, &address); + + ASSERT_EQ(OK, socket_.Open(ADDRESS_FAMILY_IPV4)); + ASSERT_EQ(OK, socket_.Bind(address)); + ASSERT_EQ(OK, socket_.Listen(kListenBacklog)); + ASSERT_EQ(OK, socket_.GetLocalAddress(&local_address_)); + } + + void SetUpListenIPv6(bool* success) { + *success = false; + IPEndPoint address; + ParseAddress("::1", 0, &address); + + if (socket_.Open(ADDRESS_FAMILY_IPV6) != OK || + socket_.Bind(address) != OK || + socket_.Listen(kListenBacklog) != OK) { + LOG(ERROR) << "Failed to listen on ::1 - probably because IPv6 is " + "disabled. Skipping the test"; + return; + } + ASSERT_EQ(OK, socket_.GetLocalAddress(&local_address_)); + *success = true; + } + + void ParseAddress(const std::string& ip_str, int port, IPEndPoint* address) { + IPAddressNumber ip_number; + bool rv = ParseIPLiteralToNumber(ip_str, &ip_number); + if (!rv) + return; + *address = IPEndPoint(ip_number, port); + } + + AddressList local_address_list() const { + return AddressList(local_address_); + } + + TCPSocket socket_; + IPEndPoint local_address_; +}; + +// Test listening and accepting with a socket bound to an IPv4 address. +TEST_F(TCPSocketTest, Accept) { + ASSERT_NO_FATAL_FAILURE(SetUpListenIPv4()); + + TestCompletionCallback connect_callback; + // TODO(yzshen): Switch to use TCPSocket when it supports client socket + // operations. + TCPClientSocket connecting_socket(local_address_list(), + NULL, NetLog::Source()); + connecting_socket.Connect(connect_callback.callback()); + + TestCompletionCallback accept_callback; + scoped_ptr<TCPSocket> accepted_socket; + IPEndPoint accepted_address; + int result = socket_.Accept(&accepted_socket, &accepted_address, + accept_callback.callback()); + if (result == ERR_IO_PENDING) + result = accept_callback.WaitForResult(); + ASSERT_EQ(OK, result); + + EXPECT_TRUE(accepted_socket.get()); + + // Both sockets should be on the loopback network interface. + EXPECT_EQ(accepted_address.address(), local_address_.address()); + + EXPECT_EQ(OK, connect_callback.WaitForResult()); +} + +// Test Accept() callback. +TEST_F(TCPSocketTest, AcceptAsync) { + ASSERT_NO_FATAL_FAILURE(SetUpListenIPv4()); + + TestCompletionCallback accept_callback; + scoped_ptr<TCPSocket> accepted_socket; + IPEndPoint accepted_address; + ASSERT_EQ(ERR_IO_PENDING, + socket_.Accept(&accepted_socket, &accepted_address, + accept_callback.callback())); + + TestCompletionCallback connect_callback; + TCPClientSocket connecting_socket(local_address_list(), + NULL, NetLog::Source()); + connecting_socket.Connect(connect_callback.callback()); + + EXPECT_EQ(OK, connect_callback.WaitForResult()); + EXPECT_EQ(OK, accept_callback.WaitForResult()); + + EXPECT_TRUE(accepted_socket.get()); + + // Both sockets should be on the loopback network interface. + EXPECT_EQ(accepted_address.address(), local_address_.address()); +} + +// Accept two connections simultaneously. +TEST_F(TCPSocketTest, Accept2Connections) { + ASSERT_NO_FATAL_FAILURE(SetUpListenIPv4()); + + TestCompletionCallback accept_callback; + scoped_ptr<TCPSocket> accepted_socket; + IPEndPoint accepted_address; + + ASSERT_EQ(ERR_IO_PENDING, + socket_.Accept(&accepted_socket, &accepted_address, + accept_callback.callback())); + + TestCompletionCallback connect_callback; + TCPClientSocket connecting_socket(local_address_list(), + NULL, NetLog::Source()); + connecting_socket.Connect(connect_callback.callback()); + + TestCompletionCallback connect_callback2; + TCPClientSocket connecting_socket2(local_address_list(), + NULL, NetLog::Source()); + connecting_socket2.Connect(connect_callback2.callback()); + + EXPECT_EQ(OK, accept_callback.WaitForResult()); + + TestCompletionCallback accept_callback2; + scoped_ptr<TCPSocket> accepted_socket2; + IPEndPoint accepted_address2; + + int result = socket_.Accept(&accepted_socket2, &accepted_address2, + accept_callback2.callback()); + if (result == ERR_IO_PENDING) + result = accept_callback2.WaitForResult(); + ASSERT_EQ(OK, result); + + EXPECT_EQ(OK, connect_callback.WaitForResult()); + EXPECT_EQ(OK, connect_callback2.WaitForResult()); + + EXPECT_TRUE(accepted_socket.get()); + EXPECT_TRUE(accepted_socket2.get()); + EXPECT_NE(accepted_socket.get(), accepted_socket2.get()); + + EXPECT_EQ(accepted_address.address(), local_address_.address()); + EXPECT_EQ(accepted_address2.address(), local_address_.address()); +} + +// Test listening and accepting with a socket bound to an IPv6 address. +TEST_F(TCPSocketTest, AcceptIPv6) { + bool initialized = false; + ASSERT_NO_FATAL_FAILURE(SetUpListenIPv6(&initialized)); + if (!initialized) + return; + + TestCompletionCallback connect_callback; + TCPClientSocket connecting_socket(local_address_list(), + NULL, NetLog::Source()); + connecting_socket.Connect(connect_callback.callback()); + + TestCompletionCallback accept_callback; + scoped_ptr<TCPSocket> accepted_socket; + IPEndPoint accepted_address; + int result = socket_.Accept(&accepted_socket, &accepted_address, + accept_callback.callback()); + if (result == ERR_IO_PENDING) + result = accept_callback.WaitForResult(); + ASSERT_EQ(OK, result); + + EXPECT_TRUE(accepted_socket.get()); + + // Both sockets should be on the loopback network interface. + EXPECT_EQ(accepted_address.address(), local_address_.address()); + + EXPECT_EQ(OK, connect_callback.WaitForResult()); +} + +TEST_F(TCPSocketTest, ReadWrite) { + ASSERT_NO_FATAL_FAILURE(SetUpListenIPv4()); + + TestCompletionCallback connect_callback; + TCPSocket connecting_socket(NULL, NetLog::Source()); + int result = connecting_socket.Open(ADDRESS_FAMILY_IPV4); + ASSERT_EQ(OK, result); + connecting_socket.Connect(local_address_, connect_callback.callback()); + + TestCompletionCallback accept_callback; + scoped_ptr<TCPSocket> accepted_socket; + IPEndPoint accepted_address; + result = socket_.Accept(&accepted_socket, &accepted_address, + accept_callback.callback()); + ASSERT_EQ(OK, accept_callback.GetResult(result)); + + ASSERT_TRUE(accepted_socket.get()); + + // Both sockets should be on the loopback network interface. + EXPECT_EQ(accepted_address.address(), local_address_.address()); + + EXPECT_EQ(OK, connect_callback.WaitForResult()); + + const std::string message("test message"); + std::vector<char> buffer(message.size()); + + size_t bytes_written = 0; + while (bytes_written < message.size()) { + scoped_refptr<IOBufferWithSize> write_buffer( + new IOBufferWithSize(message.size() - bytes_written)); + memmove(write_buffer->data(), message.data() + bytes_written, + message.size() - bytes_written); + + TestCompletionCallback write_callback; + int write_result = accepted_socket->Write( + write_buffer.get(), write_buffer->size(), write_callback.callback()); + write_result = write_callback.GetResult(write_result); + ASSERT_TRUE(write_result >= 0); + bytes_written += write_result; + ASSERT_TRUE(bytes_written <= message.size()); + } + + size_t bytes_read = 0; + while (bytes_read < message.size()) { + scoped_refptr<IOBufferWithSize> read_buffer( + new IOBufferWithSize(message.size() - bytes_read)); + TestCompletionCallback read_callback; + int read_result = connecting_socket.Read( + read_buffer.get(), read_buffer->size(), read_callback.callback()); + read_result = read_callback.GetResult(read_result); + ASSERT_TRUE(read_result >= 0); + ASSERT_TRUE(bytes_read + read_result <= message.size()); + memmove(&buffer[bytes_read], read_buffer->data(), read_result); + bytes_read += read_result; + } + + std::string received_message(buffer.begin(), buffer.end()); + ASSERT_EQ(message, received_message); +} + +} // namespace +} // namespace net diff --git a/chromium/net/socket/tcp_client_socket_win.cc b/chromium/net/socket/tcp_socket_win.cc index 9b0a5b50bf1..7d76232f962 100644 --- a/chromium/net/socket/tcp_client_socket_win.cc +++ b/chromium/net/socket/tcp_socket_win.cc @@ -1,26 +1,25 @@ -// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Copyright 2013 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#include "net/socket/tcp_client_socket_win.h" +#include "net/socket/tcp_socket_win.h" #include <mstcpip.h> -#include "base/basictypes.h" -#include "base/compiler_specific.h" +#include "base/callback_helpers.h" +#include "base/logging.h" #include "base/metrics/stats_counters.h" -#include "base/strings/string_util.h" -#include "base/win/object_watcher.h" #include "base/win/windows_version.h" +#include "net/base/address_list.h" #include "net/base/connection_type_histograms.h" #include "net/base/io_buffer.h" #include "net/base/ip_endpoint.h" #include "net/base/net_errors.h" -#include "net/base/net_log.h" #include "net/base/net_util.h" #include "net/base/network_change_notifier.h" #include "net/base/winsock_init.h" #include "net/base/winsock_util.h" +#include "net/socket/socket_descriptor.h" #include "net/socket/socket_net_log_params.h" namespace net { @@ -28,7 +27,6 @@ namespace net { namespace { const int kTCPKeepAliveSeconds = 45; -bool g_disable_overlapped_reads = false; bool SetSocketReceiveBufferSize(SOCKET socket, int32 size) { int rv = setsockopt(socket, SOL_SOCKET, SO_RCVBUF, @@ -86,8 +84,8 @@ bool SetTCPKeepAlive(SOCKET socket, BOOL enable, int delay_secs) { }; DWORD bytes_returned = 0xABAB; int rv = WSAIoctl(socket, SIO_KEEPALIVE_VALS, &keepalive_vals, - sizeof(keepalive_vals), NULL, 0, - &bytes_returned, NULL, NULL); + sizeof(keepalive_vals), NULL, 0, + &bytes_returned, NULL, NULL); DCHECK(!rv) << "Could not enable TCP Keep-Alive for socket: " << socket << " [error: " << WSAGetLastError() << "]."; @@ -95,49 +93,6 @@ bool SetTCPKeepAlive(SOCKET socket, BOOL enable, int delay_secs) { return rv == 0; } -// Sets socket parameters. Returns the OS error code (or 0 on -// success). -int SetupSocket(SOCKET socket) { - // Increase the socket buffer sizes from the default sizes for WinXP. In - // performance testing, there is substantial benefit by increasing from 8KB - // to 64KB. - // See also: - // http://support.microsoft.com/kb/823764/EN-US - // On Vista, if we manually set these sizes, Vista turns off its receive - // window auto-tuning feature. - // http://blogs.msdn.com/wndp/archive/2006/05/05/Winhec-blog-tcpip-2.aspx - // Since Vista's auto-tune is better than any static value we can could set, - // only change these on pre-vista machines. - if (base::win::GetVersion() < base::win::VERSION_VISTA) { - const int32 kSocketBufferSize = 64 * 1024; - SetSocketReceiveBufferSize(socket, kSocketBufferSize); - SetSocketSendBufferSize(socket, kSocketBufferSize); - } - - DisableNagle(socket, true); - SetTCPKeepAlive(socket, true, kTCPKeepAliveSeconds); - return 0; -} - -// Creates a new socket and sets default parameters for it. Returns -// the OS error code (or 0 on success). -int CreateSocket(int family, SOCKET* socket) { - *socket = CreatePlatformSocket(family, SOCK_STREAM, IPPROTO_TCP); - if (*socket == INVALID_SOCKET) { - int os_error = WSAGetLastError(); - LOG(ERROR) << "CreatePlatformSocket failed: " << os_error; - return os_error; - } - int error = SetupSocket(*socket); - if (error) { - if (closesocket(*socket) < 0) - PLOG(ERROR) << "closesocket"; - *socket = INVALID_SOCKET; - return error; - } - return 0; -} - int MapConnectError(int os_error) { switch (os_error) { // connect fails with WSAEACCES when Windows Firewall blocks the @@ -167,31 +122,21 @@ int MapConnectError(int os_error) { //----------------------------------------------------------------------------- // This class encapsulates all the state that has to be preserved as long as -// there is a network IO operation in progress. If the owner TCPClientSocketWin -// is destroyed while an operation is in progress, the Core is detached and it +// there is a network IO operation in progress. If the owner TCPSocketWin is +// destroyed while an operation is in progress, the Core is detached and it // lives until the operation completes and the OS doesn't reference any resource // declared on this class anymore. -class TCPClientSocketWin::Core : public base::RefCounted<Core> { +class TCPSocketWin::Core : public base::RefCounted<Core> { public: - explicit Core(TCPClientSocketWin* socket); + explicit Core(TCPSocketWin* socket); // Start watching for the end of a read or write operation. void WatchForRead(); void WatchForWrite(); - // The TCPClientSocketWin is going away. + // The TCPSocketWin is going away. void Detach() { socket_ = NULL; } - // Throttle the read size based on our current slow start state. - // Returns the throttled read size. - int ThrottleReadSize(int size) { - if (slow_start_throttle_ < kMaxSlowStartThrottle) { - size = std::min(size, slow_start_throttle_); - slow_start_throttle_ *= 2; - } - return size; - } - // The separate OVERLAPPED variables for asynchronous operation. // |read_overlapped_| is used for both Connect() and Read(). // |write_overlapped_| is only used for Write(); @@ -204,9 +149,6 @@ class TCPClientSocketWin::Core : public base::RefCounted<Core> { int read_buffer_length_; int write_buffer_length_; - // Remember the state of g_disable_overlapped_reads for the duration of the - // socket based on what it was when the socket was created. - bool disable_overlapped_reads_; bool non_blocking_reads_initialized_; private: @@ -239,7 +181,7 @@ class TCPClientSocketWin::Core : public base::RefCounted<Core> { ~Core(); // The socket that created this object. - TCPClientSocketWin* socket_; + TCPSocketWin* socket_; // |reader_| handles the signals from |read_watcher_|. ReadDelegate reader_; @@ -251,26 +193,16 @@ class TCPClientSocketWin::Core : public base::RefCounted<Core> { // |write_watcher_| watches for events from Write(); base::win::ObjectWatcher write_watcher_; - // When doing reads from the socket, we try to mirror TCP's slow start. - // We do this because otherwise the async IO subsystem artifically delays - // returning data to the application. - static const int kInitialSlowStartThrottle = 1 * 1024; - static const int kMaxSlowStartThrottle = 32 * kInitialSlowStartThrottle; - int slow_start_throttle_; - DISALLOW_COPY_AND_ASSIGN(Core); }; -TCPClientSocketWin::Core::Core( - TCPClientSocketWin* socket) +TCPSocketWin::Core::Core(TCPSocketWin* socket) : read_buffer_length_(0), write_buffer_length_(0), - disable_overlapped_reads_(g_disable_overlapped_reads), non_blocking_reads_initialized_(false), socket_(socket), reader_(this), - writer_(this), - slow_start_throttle_(kInitialSlowStartThrottle) { + writer_(this) { memset(&read_overlapped_, 0, sizeof(read_overlapped_)); memset(&write_overlapped_, 0, sizeof(write_overlapped_)); @@ -278,7 +210,7 @@ TCPClientSocketWin::Core::Core( write_overlapped_.hEvent = WSACreateEvent(); } -TCPClientSocketWin::Core::~Core() { +TCPSocketWin::Core::~Core() { // Make sure the message loop is not watching this object anymore. read_watcher_.StopWatching(); write_watcher_.StopWatching(); @@ -289,37 +221,33 @@ TCPClientSocketWin::Core::~Core() { memset(&write_overlapped_, 0xaf, sizeof(write_overlapped_)); } -void TCPClientSocketWin::Core::WatchForRead() { +void TCPSocketWin::Core::WatchForRead() { // We grab an extra reference because there is an IO operation in progress. // Balanced in ReadDelegate::OnObjectSignaled(). AddRef(); read_watcher_.StartWatching(read_overlapped_.hEvent, &reader_); } -void TCPClientSocketWin::Core::WatchForWrite() { +void TCPSocketWin::Core::WatchForWrite() { // We grab an extra reference because there is an IO operation in progress. // Balanced in WriteDelegate::OnObjectSignaled(). AddRef(); write_watcher_.StartWatching(write_overlapped_.hEvent, &writer_); } -void TCPClientSocketWin::Core::ReadDelegate::OnObjectSignaled( - HANDLE object) { +void TCPSocketWin::Core::ReadDelegate::OnObjectSignaled(HANDLE object) { DCHECK_EQ(object, core_->read_overlapped_.hEvent); if (core_->socket_) { - if (core_->socket_->waiting_connect()) { + if (core_->socket_->waiting_connect_) core_->socket_->DidCompleteConnect(); - } else if (core_->disable_overlapped_reads_) { + else core_->socket_->DidSignalRead(); - } else { - core_->socket_->DidCompleteRead(); - } } core_->Release(); } -void TCPClientSocketWin::Core::WriteDelegate::OnObjectSignaled( +void TCPSocketWin::Core::WriteDelegate::OnObjectSignaled( HANDLE object) { DCHECK_EQ(object, core_->write_overlapped_.hEvent); if (core_->socket_) @@ -330,281 +258,170 @@ void TCPClientSocketWin::Core::WriteDelegate::OnObjectSignaled( //----------------------------------------------------------------------------- -TCPClientSocketWin::TCPClientSocketWin(const AddressList& addresses, - net::NetLog* net_log, - const net::NetLog::Source& source) +TCPSocketWin::TCPSocketWin(net::NetLog* net_log, + const net::NetLog::Source& source) : socket_(INVALID_SOCKET), - bound_socket_(INVALID_SOCKET), - addresses_(addresses), - current_address_index_(-1), + accept_event_(WSA_INVALID_EVENT), + accept_socket_(NULL), + accept_address_(NULL), + waiting_connect_(false), waiting_read_(false), waiting_write_(false), - next_connect_state_(CONNECT_STATE_NONE), connect_os_error_(0), - net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)), - previously_disconnected_(false) { + logging_multiple_connect_attempts_(false), + net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)) { net_log_.BeginEvent(NetLog::TYPE_SOCKET_ALIVE, source.ToEventParametersCallback()); EnsureWinsockInit(); } -TCPClientSocketWin::~TCPClientSocketWin() { - Disconnect(); +TCPSocketWin::~TCPSocketWin() { + Close(); net_log_.EndEvent(NetLog::TYPE_SOCKET_ALIVE); } -int TCPClientSocketWin::AdoptSocket(SOCKET socket) { +int TCPSocketWin::Open(AddressFamily family) { + DCHECK(CalledOnValidThread()); DCHECK_EQ(socket_, INVALID_SOCKET); - int error = SetupSocket(socket); - if (error) - return MapSystemError(error); - - socket_ = socket; - SetNonBlocking(socket_); - - core_ = new Core(this); - current_address_index_ = 0; - use_history_.set_was_ever_connected(); - - return OK; -} - -int TCPClientSocketWin::Bind(const IPEndPoint& address) { - if (current_address_index_ >= 0 || bind_address_.get()) { - // Cannot bind the socket if we are already connected or connecting. - return ERR_UNEXPECTED; + socket_ = CreatePlatformSocket(ConvertAddressFamily(family), SOCK_STREAM, + IPPROTO_TCP); + if (socket_ == INVALID_SOCKET) { + PLOG(ERROR) << "CreatePlatformSocket() returned an error"; + return MapSystemError(WSAGetLastError()); } - SockaddrStorage storage; - if (!address.ToSockAddr(storage.addr, &storage.addr_len)) - return ERR_INVALID_ARGUMENT; - - // Create |bound_socket_| and try to bind it to |address|. - int error = CreateSocket(address.GetSockAddrFamily(), &bound_socket_); - if (error) - return MapSystemError(error); - - if (bind(bound_socket_, storage.addr, storage.addr_len)) { - error = errno; - if (closesocket(bound_socket_) < 0) - PLOG(ERROR) << "closesocket"; - bound_socket_ = INVALID_SOCKET; - return MapSystemError(error); + if (SetNonBlocking(socket_)) { + int result = MapSystemError(WSAGetLastError()); + Close(); + return result; } - bind_address_.reset(new IPEndPoint(address)); - - return 0; + return OK; } - -int TCPClientSocketWin::Connect(const CompletionCallback& callback) { +int TCPSocketWin::AdoptConnectedSocket(SOCKET socket, + const IPEndPoint& peer_address) { DCHECK(CalledOnValidThread()); + DCHECK_EQ(socket_, INVALID_SOCKET); + DCHECK(!core_); - // If already connected, then just return OK. - if (socket_ != INVALID_SOCKET) - return OK; - - base::StatsCounter connects("tcp.connect"); - connects.Increment(); - - net_log_.BeginEvent(NetLog::TYPE_TCP_CONNECT, - addresses_.CreateNetLogCallback()); - - // We will try to connect to each address in addresses_. Start with the - // first one in the list. - next_connect_state_ = CONNECT_STATE_CONNECT; - current_address_index_ = 0; + socket_ = socket; - int rv = DoConnectLoop(OK); - if (rv == ERR_IO_PENDING) { - // Synchronous operation not supported. - DCHECK(!callback.is_null()); - // TODO(ajwong): Is setting read_callback_ the right thing to do here?? - read_callback_ = callback; - } else { - LogConnectCompletion(rv); + if (SetNonBlocking(socket_)) { + int result = MapSystemError(WSAGetLastError()); + Close(); + return result; } - return rv; -} - -int TCPClientSocketWin::DoConnectLoop(int result) { - DCHECK_NE(next_connect_state_, CONNECT_STATE_NONE); - - int rv = result; - do { - ConnectState state = next_connect_state_; - next_connect_state_ = CONNECT_STATE_NONE; - switch (state) { - case CONNECT_STATE_CONNECT: - DCHECK_EQ(OK, rv); - rv = DoConnect(); - break; - case CONNECT_STATE_CONNECT_COMPLETE: - rv = DoConnectComplete(rv); - break; - default: - LOG(DFATAL) << "bad state " << state; - rv = ERR_UNEXPECTED; - break; - } - } while (rv != ERR_IO_PENDING && next_connect_state_ != CONNECT_STATE_NONE); + core_ = new Core(this); + peer_address_.reset(new IPEndPoint(peer_address)); - return rv; + return OK; } -int TCPClientSocketWin::DoConnect() { - DCHECK_GE(current_address_index_, 0); - DCHECK_LT(current_address_index_, static_cast<int>(addresses_.size())); - DCHECK_EQ(0, connect_os_error_); +int TCPSocketWin::Bind(const IPEndPoint& address) { + DCHECK(CalledOnValidThread()); + DCHECK_NE(socket_, INVALID_SOCKET); - const IPEndPoint& endpoint = addresses_[current_address_index_]; + SockaddrStorage storage; + if (!address.ToSockAddr(storage.addr, &storage.addr_len)) + return ERR_ADDRESS_INVALID; - if (previously_disconnected_) { - use_history_.Reset(); - previously_disconnected_ = false; + int result = bind(socket_, storage.addr, storage.addr_len); + if (result < 0) { + PLOG(ERROR) << "bind() returned an error"; + return MapSystemError(WSAGetLastError()); } - net_log_.BeginEvent(NetLog::TYPE_TCP_CONNECT_ATTEMPT, - CreateNetLogIPEndPointCallback(&endpoint)); + return OK; +} - next_connect_state_ = CONNECT_STATE_CONNECT_COMPLETE; +int TCPSocketWin::Listen(int backlog) { + DCHECK(CalledOnValidThread()); + DCHECK_GT(backlog, 0); + DCHECK_NE(socket_, INVALID_SOCKET); + DCHECK_EQ(accept_event_, WSA_INVALID_EVENT); - if (bound_socket_ != INVALID_SOCKET) { - DCHECK(bind_address_.get()); - socket_ = bound_socket_; - bound_socket_ = INVALID_SOCKET; - } else { - connect_os_error_ = CreateSocket(endpoint.GetSockAddrFamily(), &socket_); - if (connect_os_error_ != 0) - return MapSystemError(connect_os_error_); - - if (bind_address_.get()) { - SockaddrStorage storage; - if (!bind_address_->ToSockAddr(storage.addr, &storage.addr_len)) - return ERR_INVALID_ARGUMENT; - if (bind(socket_, storage.addr, storage.addr_len)) - return MapSystemError(errno); - } + accept_event_ = WSACreateEvent(); + if (accept_event_ == WSA_INVALID_EVENT) { + PLOG(ERROR) << "WSACreateEvent()"; + return MapSystemError(WSAGetLastError()); } - DCHECK(!core_); - core_ = new Core(this); - // WSAEventSelect sets the socket to non-blocking mode as a side effect. - // Our connect() and recv() calls require that the socket be non-blocking. - WSAEventSelect(socket_, core_->read_overlapped_.hEvent, FD_CONNECT); - - SockaddrStorage storage; - if (!endpoint.ToSockAddr(storage.addr, &storage.addr_len)) - return ERR_INVALID_ARGUMENT; - if (!connect(socket_, storage.addr, storage.addr_len)) { - // Connected without waiting! - // - // The MSDN page for connect says: - // With a nonblocking socket, the connection attempt cannot be completed - // immediately. In this case, connect will return SOCKET_ERROR, and - // WSAGetLastError will return WSAEWOULDBLOCK. - // which implies that for a nonblocking socket, connect never returns 0. - // It's not documented whether the event object will be signaled or not - // if connect does return 0. So the code below is essentially dead code - // and we don't know if it's correct. - NOTREACHED(); - - if (ResetEventIfSignaled(core_->read_overlapped_.hEvent)) - return OK; - } else { - int os_error = WSAGetLastError(); - if (os_error != WSAEWOULDBLOCK) { - LOG(ERROR) << "connect failed: " << os_error; - connect_os_error_ = os_error; - return MapConnectError(os_error); - } + int result = listen(socket_, backlog); + if (result < 0) { + PLOG(ERROR) << "listen() returned an error"; + return MapSystemError(WSAGetLastError()); } - core_->WatchForRead(); - return ERR_IO_PENDING; + return OK; } -int TCPClientSocketWin::DoConnectComplete(int result) { - // Log the end of this attempt (and any OS error it threw). - int os_error = connect_os_error_; - connect_os_error_ = 0; - if (result != OK) { - net_log_.EndEvent(NetLog::TYPE_TCP_CONNECT_ATTEMPT, - NetLog::IntegerCallback("os_error", os_error)); - } else { - net_log_.EndEvent(NetLog::TYPE_TCP_CONNECT_ATTEMPT); - } +int TCPSocketWin::Accept(scoped_ptr<TCPSocketWin>* socket, + IPEndPoint* address, + const CompletionCallback& callback) { + DCHECK(CalledOnValidThread()); + DCHECK(socket); + DCHECK(address); + DCHECK(!callback.is_null()); + DCHECK(accept_callback_.is_null()); - if (result == OK) { - use_history_.set_was_ever_connected(); - return OK; // Done! - } + net_log_.BeginEvent(NetLog::TYPE_TCP_ACCEPT); - // Close whatever partially connected socket we currently have. - DoDisconnect(); + int result = AcceptInternal(socket, address); - // Try to fall back to the next address in the list. - if (current_address_index_ + 1 < static_cast<int>(addresses_.size())) { - next_connect_state_ = CONNECT_STATE_CONNECT; - ++current_address_index_; - return OK; + if (result == ERR_IO_PENDING) { + // Start watching. + WSAEventSelect(socket_, accept_event_, FD_ACCEPT); + accept_watcher_.StartWatching(accept_event_, this); + + accept_socket_ = socket; + accept_address_ = address; + accept_callback_ = callback; } - // Otherwise there is nothing to fall back to, so give up. return result; } -void TCPClientSocketWin::Disconnect() { +int TCPSocketWin::Connect(const IPEndPoint& address, + const CompletionCallback& callback) { DCHECK(CalledOnValidThread()); + DCHECK_NE(socket_, INVALID_SOCKET); + DCHECK(!waiting_connect_); - DoDisconnect(); - current_address_index_ = -1; - bind_address_.reset(); -} + // |peer_address_| and |core_| will be non-NULL if Connect() has been called. + // Unless Close() is called to reset the internal state, a second call to + // Connect() is not allowed. + // Please note that we enforce this even if the previous Connect() has + // completed and failed. Although it is allowed to connect the same |socket_| + // again after a connection attempt failed on Windows, it results in + // unspecified behavior according to POSIX. Therefore, we make it behave in + // the same way as TCPSocketLibevent. + DCHECK(!peer_address_ && !core_); -void TCPClientSocketWin::DoDisconnect() { - DCHECK(CalledOnValidThread()); + if (!logging_multiple_connect_attempts_) + LogConnectBegin(AddressList(address)); - if (socket_ == INVALID_SOCKET) - return; + peer_address_.reset(new IPEndPoint(address)); - // Note: don't use CancelIo to cancel pending IO because it doesn't work - // when there is a Winsock layered service provider. - - // In most socket implementations, closing a socket results in a graceful - // connection shutdown, but in Winsock we have to call shutdown explicitly. - // See the MSDN page "Graceful Shutdown, Linger Options, and Socket Closure" - // at http://msdn.microsoft.com/en-us/library/ms738547.aspx - shutdown(socket_, SD_SEND); - - // This cancels any pending IO. - closesocket(socket_); - socket_ = INVALID_SOCKET; - - if (waiting_connect()) { - // We closed the socket, so this notification will never come. - // From MSDN' WSAEventSelect documentation: - // "Closing a socket with closesocket also cancels the association and - // selection of network events specified in WSAEventSelect for the socket". - core_->Release(); + int rv = DoConnect(); + if (rv == ERR_IO_PENDING) { + // Synchronous operation not supported. + DCHECK(!callback.is_null()); + read_callback_ = callback; + waiting_connect_ = true; + } else { + DoConnectComplete(rv); } - waiting_read_ = false; - waiting_write_ = false; - - core_->Detach(); - core_ = NULL; - - previously_disconnected_ = true; + return rv; } -bool TCPClientSocketWin::IsConnected() const { +bool TCPSocketWin::IsConnected() const { DCHECK(CalledOnValidThread()); - if (socket_ == INVALID_SOCKET || waiting_connect()) + if (socket_ == INVALID_SOCKET || waiting_connect_) return false; if (waiting_read_) @@ -621,10 +438,10 @@ bool TCPClientSocketWin::IsConnected() const { return true; } -bool TCPClientSocketWin::IsConnectedAndIdle() const { +bool TCPSocketWin::IsConnectedAndIdle() const { DCHECK(CalledOnValidThread()); - if (socket_ == INVALID_SOCKET || waiting_connect()) + if (socket_ == INVALID_SOCKET || waiting_connect_) return false; if (waiting_read_) @@ -642,68 +459,9 @@ bool TCPClientSocketWin::IsConnectedAndIdle() const { return true; } -int TCPClientSocketWin::GetPeerAddress(IPEndPoint* address) const { - DCHECK(CalledOnValidThread()); - DCHECK(address); - if (!IsConnected()) - return ERR_SOCKET_NOT_CONNECTED; - *address = addresses_[current_address_index_]; - return OK; -} - -int TCPClientSocketWin::GetLocalAddress(IPEndPoint* address) const { - DCHECK(CalledOnValidThread()); - DCHECK(address); - if (socket_ == INVALID_SOCKET) { - if (bind_address_.get()) { - *address = *bind_address_; - return OK; - } - return ERR_SOCKET_NOT_CONNECTED; - } - - struct sockaddr_storage addr_storage; - socklen_t addr_len = sizeof(addr_storage); - struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); - if (getsockname(socket_, addr, &addr_len)) - return MapSystemError(WSAGetLastError()); - if (!address->FromSockAddr(addr, addr_len)) - return ERR_FAILED; - return OK; -} - -void TCPClientSocketWin::SetSubresourceSpeculation() { - use_history_.set_subresource_speculation(); -} - -void TCPClientSocketWin::SetOmniboxSpeculation() { - use_history_.set_omnibox_speculation(); -} - -bool TCPClientSocketWin::WasEverUsed() const { - return use_history_.was_used_to_convey_data(); -} - -bool TCPClientSocketWin::UsingTCPFastOpen() const { - // Not supported on windows. - return false; -} - -bool TCPClientSocketWin::WasNpnNegotiated() const { - return false; -} - -NextProto TCPClientSocketWin::GetNegotiatedProtocol() const { - return kProtoUnknown; -} - -bool TCPClientSocketWin::GetSSLInfo(SSLInfo* ssl_info) { - return false; -} - -int TCPClientSocketWin::Read(IOBuffer* buf, - int buf_len, - const CompletionCallback& callback) { +int TCPSocketWin::Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) { DCHECK(CalledOnValidThread()); DCHECK_NE(socket_, INVALID_SOCKET); DCHECK(!waiting_read_); @@ -713,9 +471,9 @@ int TCPClientSocketWin::Read(IOBuffer* buf, return DoRead(buf, buf_len, callback); } -int TCPClientSocketWin::Write(IOBuffer* buf, - int buf_len, - const CompletionCallback& callback) { +int TCPSocketWin::Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) { DCHECK(CalledOnValidThread()); DCHECK_NE(socket_, INVALID_SOCKET); DCHECK(!waiting_write_); @@ -747,8 +505,6 @@ int TCPClientSocketWin::Write(IOBuffer* buf, } base::StatsCounter write_bytes("tcp.write_bytes"); write_bytes.Add(rv); - if (rv > 0) - use_history_.set_was_used_to_convey_data(); net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_SENT, rv, buf->data()); return rv; @@ -770,29 +526,294 @@ int TCPClientSocketWin::Write(IOBuffer* buf, return ERR_IO_PENDING; } -bool TCPClientSocketWin::SetReceiveBufferSize(int32 size) { +int TCPSocketWin::GetLocalAddress(IPEndPoint* address) const { + DCHECK(CalledOnValidThread()); + DCHECK(address); + + SockaddrStorage storage; + if (getsockname(socket_, storage.addr, &storage.addr_len)) + return MapSystemError(WSAGetLastError()); + if (!address->FromSockAddr(storage.addr, storage.addr_len)) + return ERR_ADDRESS_INVALID; + + return OK; +} + +int TCPSocketWin::GetPeerAddress(IPEndPoint* address) const { + DCHECK(CalledOnValidThread()); + DCHECK(address); + if (!IsConnected()) + return ERR_SOCKET_NOT_CONNECTED; + *address = *peer_address_; + return OK; +} + +int TCPSocketWin::SetDefaultOptionsForServer() { + return SetExclusiveAddrUse(); +} + +void TCPSocketWin::SetDefaultOptionsForClient() { + // Increase the socket buffer sizes from the default sizes for WinXP. In + // performance testing, there is substantial benefit by increasing from 8KB + // to 64KB. + // See also: + // http://support.microsoft.com/kb/823764/EN-US + // On Vista, if we manually set these sizes, Vista turns off its receive + // window auto-tuning feature. + // http://blogs.msdn.com/wndp/archive/2006/05/05/Winhec-blog-tcpip-2.aspx + // Since Vista's auto-tune is better than any static value we can could set, + // only change these on pre-vista machines. + if (base::win::GetVersion() < base::win::VERSION_VISTA) { + const int32 kSocketBufferSize = 64 * 1024; + SetSocketReceiveBufferSize(socket_, kSocketBufferSize); + SetSocketSendBufferSize(socket_, kSocketBufferSize); + } + + DisableNagle(socket_, true); + SetTCPKeepAlive(socket_, true, kTCPKeepAliveSeconds); +} + +int TCPSocketWin::SetExclusiveAddrUse() { + // On Windows, a bound end point can be hijacked by another process by + // setting SO_REUSEADDR. Therefore a Windows-only option SO_EXCLUSIVEADDRUSE + // was introduced in Windows NT 4.0 SP4. If the socket that is bound to the + // end point has SO_EXCLUSIVEADDRUSE enabled, it is not possible for another + // socket to forcibly bind to the end point until the end point is unbound. + // It is recommend that all server applications must use SO_EXCLUSIVEADDRUSE. + // MSDN: http://goo.gl/M6fjQ. + // + // Unlike on *nix, on Windows a TCP server socket can always bind to an end + // point in TIME_WAIT state without setting SO_REUSEADDR, therefore it is not + // needed here. + // + // SO_EXCLUSIVEADDRUSE will prevent a TCP client socket from binding to an end + // point in TIME_WAIT status. It does not have this effect for a TCP server + // socket. + + BOOL true_value = 1; + int rv = setsockopt(socket_, SOL_SOCKET, SO_EXCLUSIVEADDRUSE, + reinterpret_cast<const char*>(&true_value), + sizeof(true_value)); + if (rv < 0) + return MapSystemError(errno); + return OK; +} + +bool TCPSocketWin::SetReceiveBufferSize(int32 size) { DCHECK(CalledOnValidThread()); return SetSocketReceiveBufferSize(socket_, size); } -bool TCPClientSocketWin::SetSendBufferSize(int32 size) { +bool TCPSocketWin::SetSendBufferSize(int32 size) { DCHECK(CalledOnValidThread()); return SetSocketSendBufferSize(socket_, size); } -bool TCPClientSocketWin::SetKeepAlive(bool enable, int delay) { +bool TCPSocketWin::SetKeepAlive(bool enable, int delay) { return SetTCPKeepAlive(socket_, enable, delay); } -bool TCPClientSocketWin::SetNoDelay(bool no_delay) { +bool TCPSocketWin::SetNoDelay(bool no_delay) { return DisableNagle(socket_, no_delay); } -void TCPClientSocketWin::DisableOverlappedReads() { - g_disable_overlapped_reads = true; +void TCPSocketWin::Close() { + DCHECK(CalledOnValidThread()); + + if (socket_ != INVALID_SOCKET) { + // Note: don't use CancelIo to cancel pending IO because it doesn't work + // when there is a Winsock layered service provider. + + // In most socket implementations, closing a socket results in a graceful + // connection shutdown, but in Winsock we have to call shutdown explicitly. + // See the MSDN page "Graceful Shutdown, Linger Options, and Socket Closure" + // at http://msdn.microsoft.com/en-us/library/ms738547.aspx + shutdown(socket_, SD_SEND); + + // This cancels any pending IO. + if (closesocket(socket_) < 0) + PLOG(ERROR) << "closesocket"; + socket_ = INVALID_SOCKET; + } + + if (accept_event_) { + WSACloseEvent(accept_event_); + accept_event_ = WSA_INVALID_EVENT; + } + + if (!accept_callback_.is_null()) { + accept_watcher_.StopWatching(); + accept_socket_ = NULL; + accept_address_ = NULL; + accept_callback_.Reset(); + } + + if (core_) { + if (waiting_connect_) { + // We closed the socket, so this notification will never come. + // From MSDN' WSAEventSelect documentation: + // "Closing a socket with closesocket also cancels the association and + // selection of network events specified in WSAEventSelect for the + // socket". + core_->Release(); + } + core_->Detach(); + core_ = NULL; + } + + waiting_connect_ = false; + waiting_read_ = false; + waiting_write_ = false; + + read_callback_.Reset(); + write_callback_.Reset(); + peer_address_.reset(); + connect_os_error_ = 0; } -void TCPClientSocketWin::LogConnectCompletion(int net_error) { +bool TCPSocketWin::UsingTCPFastOpen() const { + // Not supported on windows. + return false; +} + +void TCPSocketWin::StartLoggingMultipleConnectAttempts( + const AddressList& addresses) { + if (!logging_multiple_connect_attempts_) { + logging_multiple_connect_attempts_ = true; + LogConnectBegin(addresses); + } else { + NOTREACHED(); + } +} + +void TCPSocketWin::EndLoggingMultipleConnectAttempts(int net_error) { + if (logging_multiple_connect_attempts_) { + LogConnectEnd(net_error); + logging_multiple_connect_attempts_ = false; + } else { + NOTREACHED(); + } +} + +int TCPSocketWin::AcceptInternal(scoped_ptr<TCPSocketWin>* socket, + IPEndPoint* address) { + SockaddrStorage storage; + int new_socket = accept(socket_, storage.addr, &storage.addr_len); + if (new_socket < 0) { + int net_error = MapSystemError(WSAGetLastError()); + if (net_error != ERR_IO_PENDING) + net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_ACCEPT, net_error); + return net_error; + } + + IPEndPoint ip_end_point; + if (!ip_end_point.FromSockAddr(storage.addr, storage.addr_len)) { + NOTREACHED(); + if (closesocket(new_socket) < 0) + PLOG(ERROR) << "closesocket"; + net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_ACCEPT, ERR_FAILED); + return ERR_FAILED; + } + scoped_ptr<TCPSocketWin> tcp_socket(new TCPSocketWin( + net_log_.net_log(), net_log_.source())); + int adopt_result = tcp_socket->AdoptConnectedSocket(new_socket, ip_end_point); + if (adopt_result != OK) { + net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_ACCEPT, adopt_result); + return adopt_result; + } + *socket = tcp_socket.Pass(); + *address = ip_end_point; + net_log_.EndEvent(NetLog::TYPE_TCP_ACCEPT, + CreateNetLogIPEndPointCallback(&ip_end_point)); + return OK; +} + +void TCPSocketWin::OnObjectSignaled(HANDLE object) { + WSANETWORKEVENTS ev; + if (WSAEnumNetworkEvents(socket_, accept_event_, &ev) == SOCKET_ERROR) { + PLOG(ERROR) << "WSAEnumNetworkEvents()"; + return; + } + + if (ev.lNetworkEvents & FD_ACCEPT) { + int result = AcceptInternal(accept_socket_, accept_address_); + if (result != ERR_IO_PENDING) { + accept_socket_ = NULL; + accept_address_ = NULL; + base::ResetAndReturn(&accept_callback_).Run(result); + } + } +} + +int TCPSocketWin::DoConnect() { + DCHECK_EQ(connect_os_error_, 0); + DCHECK(!core_); + + net_log_.BeginEvent(NetLog::TYPE_TCP_CONNECT_ATTEMPT, + CreateNetLogIPEndPointCallback(peer_address_.get())); + + core_ = new Core(this); + // WSAEventSelect sets the socket to non-blocking mode as a side effect. + // Our connect() and recv() calls require that the socket be non-blocking. + WSAEventSelect(socket_, core_->read_overlapped_.hEvent, FD_CONNECT); + + SockaddrStorage storage; + if (!peer_address_->ToSockAddr(storage.addr, &storage.addr_len)) + return ERR_INVALID_ARGUMENT; + if (!connect(socket_, storage.addr, storage.addr_len)) { + // Connected without waiting! + // + // The MSDN page for connect says: + // With a nonblocking socket, the connection attempt cannot be completed + // immediately. In this case, connect will return SOCKET_ERROR, and + // WSAGetLastError will return WSAEWOULDBLOCK. + // which implies that for a nonblocking socket, connect never returns 0. + // It's not documented whether the event object will be signaled or not + // if connect does return 0. So the code below is essentially dead code + // and we don't know if it's correct. + NOTREACHED(); + + if (ResetEventIfSignaled(core_->read_overlapped_.hEvent)) + return OK; + } else { + int os_error = WSAGetLastError(); + if (os_error != WSAEWOULDBLOCK) { + LOG(ERROR) << "connect failed: " << os_error; + connect_os_error_ = os_error; + int rv = MapConnectError(os_error); + CHECK_NE(ERR_IO_PENDING, rv); + return rv; + } + } + + core_->WatchForRead(); + return ERR_IO_PENDING; +} + +void TCPSocketWin::DoConnectComplete(int result) { + // Log the end of this attempt (and any OS error it threw). + int os_error = connect_os_error_; + connect_os_error_ = 0; + if (result != OK) { + net_log_.EndEvent(NetLog::TYPE_TCP_CONNECT_ATTEMPT, + NetLog::IntegerCallback("os_error", os_error)); + } else { + net_log_.EndEvent(NetLog::TYPE_TCP_CONNECT_ATTEMPT); + } + + if (!logging_multiple_connect_attempts_) + LogConnectEnd(result); +} + +void TCPSocketWin::LogConnectBegin(const AddressList& addresses) { + base::StatsCounter connects("tcp.connect"); + connects.Increment(); + + net_log_.BeginEvent(NetLog::TYPE_TCP_CONNECT, + addresses.CreateNetLogCallback()); +} + +void TCPSocketWin::LogConnectEnd(int net_error) { if (net_error == OK) UpdateConnectionTypeHistograms(CONNECTION_ANY); @@ -820,66 +841,30 @@ void TCPClientSocketWin::LogConnectCompletion(int net_error) { sizeof(source_address))); } -int TCPClientSocketWin::DoRead(IOBuffer* buf, int buf_len, - const CompletionCallback& callback) { - if (core_->disable_overlapped_reads_) { - if (!core_->non_blocking_reads_initialized_) { - WSAEventSelect(socket_, core_->read_overlapped_.hEvent, - FD_READ | FD_CLOSE); - core_->non_blocking_reads_initialized_ = true; - } - int rv = recv(socket_, buf->data(), buf_len, 0); - if (rv == SOCKET_ERROR) { - int os_error = WSAGetLastError(); - if (os_error != WSAEWOULDBLOCK) { - int net_error = MapSystemError(os_error); - net_log_.AddEvent(NetLog::TYPE_SOCKET_READ_ERROR, - CreateNetLogSocketErrorCallback(net_error, os_error)); - return net_error; - } - } else { - base::StatsCounter read_bytes("tcp.read_bytes"); - if (rv > 0) { - use_history_.set_was_used_to_convey_data(); - read_bytes.Add(rv); - } - net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_RECEIVED, rv, - buf->data()); - return rv; +int TCPSocketWin::DoRead(IOBuffer* buf, int buf_len, + const CompletionCallback& callback) { + if (!core_->non_blocking_reads_initialized_) { + WSAEventSelect(socket_, core_->read_overlapped_.hEvent, + FD_READ | FD_CLOSE); + core_->non_blocking_reads_initialized_ = true; + } + int rv = recv(socket_, buf->data(), buf_len, 0); + if (rv == SOCKET_ERROR) { + int os_error = WSAGetLastError(); + if (os_error != WSAEWOULDBLOCK) { + int net_error = MapSystemError(os_error); + net_log_.AddEvent( + NetLog::TYPE_SOCKET_READ_ERROR, + CreateNetLogSocketErrorCallback(net_error, os_error)); + return net_error; } } else { - buf_len = core_->ThrottleReadSize(buf_len); - - WSABUF read_buffer; - read_buffer.len = buf_len; - read_buffer.buf = buf->data(); - - // TODO(wtc): Remove the assertion after enough testing. - AssertEventNotSignaled(core_->read_overlapped_.hEvent); - DWORD num; - DWORD flags = 0; - int rv = WSARecv(socket_, &read_buffer, 1, &num, &flags, - &core_->read_overlapped_, NULL); - if (rv == 0) { - if (ResetEventIfSignaled(core_->read_overlapped_.hEvent)) { - base::StatsCounter read_bytes("tcp.read_bytes"); - if (num > 0) { - use_history_.set_was_used_to_convey_data(); - read_bytes.Add(num); - } - net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_RECEIVED, num, - buf->data()); - return static_cast<int>(num); - } - } else { - int os_error = WSAGetLastError(); - if (os_error != WSA_IO_PENDING) { - int net_error = MapSystemError(os_error); - net_log_.AddEvent(NetLog::TYPE_SOCKET_READ_ERROR, - CreateNetLogSocketErrorCallback(net_error, os_error)); - return net_error; - } - } + base::StatsCounter read_bytes("tcp.read_bytes"); + if (rv > 0) + read_bytes.Add(rv); + net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_RECEIVED, rv, + buf->data()); + return rv; } waiting_read_ = true; @@ -890,28 +875,9 @@ int TCPClientSocketWin::DoRead(IOBuffer* buf, int buf_len, return ERR_IO_PENDING; } -void TCPClientSocketWin::DoReadCallback(int rv) { - DCHECK_NE(rv, ERR_IO_PENDING); +void TCPSocketWin::DidCompleteConnect() { + DCHECK(waiting_connect_); DCHECK(!read_callback_.is_null()); - - // Since Run may result in Read being called, clear read_callback_ up front. - CompletionCallback c = read_callback_; - read_callback_.Reset(); - c.Run(rv); -} - -void TCPClientSocketWin::DoWriteCallback(int rv) { - DCHECK_NE(rv, ERR_IO_PENDING); - DCHECK(!write_callback_.is_null()); - - // since Run may result in Write being called, clear write_callback_ up front. - CompletionCallback c = write_callback_; - write_callback_.Reset(); - c.Run(rv); -} - -void TCPClientSocketWin::DidCompleteConnect() { - DCHECK_EQ(next_connect_state_, CONNECT_STATE_CONNECT_COMPLETE); int result; WSANETWORKEVENTS events; @@ -931,42 +897,16 @@ void TCPClientSocketWin::DidCompleteConnect() { } connect_os_error_ = os_error; - rv = DoConnectLoop(result); - if (rv != ERR_IO_PENDING) { - LogConnectCompletion(rv); - DoReadCallback(rv); - } -} + DoConnectComplete(result); + waiting_connect_ = false; -void TCPClientSocketWin::DidCompleteRead() { - DCHECK(waiting_read_); - DWORD num_bytes, flags; - BOOL ok = WSAGetOverlappedResult(socket_, &core_->read_overlapped_, - &num_bytes, FALSE, &flags); - waiting_read_ = false; - int rv; - if (ok) { - base::StatsCounter read_bytes("tcp.read_bytes"); - read_bytes.Add(num_bytes); - if (num_bytes > 0) - use_history_.set_was_used_to_convey_data(); - net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_RECEIVED, - num_bytes, core_->read_iobuffer_->data()); - rv = static_cast<int>(num_bytes); - } else { - int os_error = WSAGetLastError(); - rv = MapSystemError(os_error); - net_log_.AddEvent(NetLog::TYPE_SOCKET_READ_ERROR, - CreateNetLogSocketErrorCallback(rv, os_error)); - } - WSAResetEvent(core_->read_overlapped_.hEvent); - core_->read_iobuffer_ = NULL; - core_->read_buffer_length_ = 0; - DoReadCallback(rv); + DCHECK_NE(result, ERR_IO_PENDING); + base::ResetAndReturn(&read_callback_).Run(result); } -void TCPClientSocketWin::DidCompleteWrite() { +void TCPSocketWin::DidCompleteWrite() { DCHECK(waiting_write_); + DCHECK(!write_callback_.is_null()); DWORD num_bytes, flags; BOOL ok = WSAGetOverlappedResult(socket_, &core_->write_overlapped_, @@ -991,18 +931,21 @@ void TCPClientSocketWin::DidCompleteWrite() { } else { base::StatsCounter write_bytes("tcp.write_bytes"); write_bytes.Add(num_bytes); - if (num_bytes > 0) - use_history_.set_was_used_to_convey_data(); net_log_.AddByteTransferEvent(NetLog::TYPE_SOCKET_BYTES_SENT, num_bytes, core_->write_iobuffer_->data()); } } + core_->write_iobuffer_ = NULL; - DoWriteCallback(rv); + + DCHECK_NE(rv, ERR_IO_PENDING); + base::ResetAndReturn(&write_callback_).Run(rv); } -void TCPClientSocketWin::DidSignalRead() { +void TCPSocketWin::DidSignalRead() { DCHECK(waiting_read_); + DCHECK(!read_callback_.is_null()); + int os_error = 0; WSANETWORKEVENTS network_events; int rv = WSAEnumNetworkEvents(socket_, core_->read_overlapped_.hEvent, @@ -1036,10 +979,14 @@ void TCPClientSocketWin::DidSignalRead() { core_->WatchForRead(); return; } + waiting_read_ = false; core_->read_iobuffer_ = NULL; core_->read_buffer_length_ = 0; - DoReadCallback(rv); + + DCHECK_NE(rv, ERR_IO_PENDING); + base::ResetAndReturn(&read_callback_).Run(rv); } } // namespace net + diff --git a/chromium/net/socket/tcp_socket_win.h b/chromium/net/socket/tcp_socket_win.h new file mode 100644 index 00000000000..df5fbf09aec --- /dev/null +++ b/chromium/net/socket/tcp_socket_win.h @@ -0,0 +1,150 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_SOCKET_TCP_SOCKET_WIN_H_ +#define NET_SOCKET_TCP_SOCKET_WIN_H_ + +#include <winsock2.h> + +#include "base/basictypes.h" +#include "base/compiler_specific.h" +#include "base/memory/ref_counted.h" +#include "base/memory/scoped_ptr.h" +#include "base/threading/non_thread_safe.h" +#include "base/win/object_watcher.h" +#include "net/base/address_family.h" +#include "net/base/completion_callback.h" +#include "net/base/net_export.h" +#include "net/base/net_log.h" + +namespace net { + +class AddressList; +class IOBuffer; +class IPEndPoint; + +class NET_EXPORT TCPSocketWin : NON_EXPORTED_BASE(public base::NonThreadSafe), + public base::win::ObjectWatcher::Delegate { + public: + TCPSocketWin(NetLog* net_log, const NetLog::Source& source); + virtual ~TCPSocketWin(); + + int Open(AddressFamily family); + // Takes ownership of |socket|. + int AdoptConnectedSocket(SOCKET socket, const IPEndPoint& peer_address); + + int Bind(const IPEndPoint& address); + + int Listen(int backlog); + int Accept(scoped_ptr<TCPSocketWin>* socket, + IPEndPoint* address, + const CompletionCallback& callback); + + int Connect(const IPEndPoint& address, const CompletionCallback& callback); + bool IsConnected() const; + bool IsConnectedAndIdle() const; + + // Multiple outstanding requests are not supported. + // Full duplex mode (reading and writing at the same time) is supported. + int Read(IOBuffer* buf, int buf_len, const CompletionCallback& callback); + int Write(IOBuffer* buf, int buf_len, const CompletionCallback& callback); + + int GetLocalAddress(IPEndPoint* address) const; + int GetPeerAddress(IPEndPoint* address) const; + + // Sets various socket options. + // The commonly used options for server listening sockets: + // - SetExclusiveAddrUse(). + int SetDefaultOptionsForServer(); + // The commonly used options for client sockets and accepted sockets: + // - Increase the socket buffer sizes for WinXP; + // - SetNoDelay(true); + // - SetKeepAlive(true, 45). + void SetDefaultOptionsForClient(); + int SetExclusiveAddrUse(); + bool SetReceiveBufferSize(int32 size); + bool SetSendBufferSize(int32 size); + bool SetKeepAlive(bool enable, int delay); + bool SetNoDelay(bool no_delay); + + void Close(); + + bool UsingTCPFastOpen() const; + bool IsValid() const { return socket_ != INVALID_SOCKET; } + + // Marks the start/end of a series of connect attempts for logging purpose. + // + // TCPClientSocket may attempt to connect to multiple addresses until it + // succeeds in establishing a connection. The corresponding log will have + // multiple NetLog::TYPE_TCP_CONNECT_ATTEMPT entries nested within a + // NetLog::TYPE_TCP_CONNECT. These methods set the start/end of + // NetLog::TYPE_TCP_CONNECT. + // + // TODO(yzshen): Change logging format and let TCPClientSocket log the + // start/end of a series of connect attempts itself. + void StartLoggingMultipleConnectAttempts(const AddressList& addresses); + void EndLoggingMultipleConnectAttempts(int net_error); + + const BoundNetLog& net_log() const { return net_log_; } + + private: + class Core; + + // base::ObjectWatcher::Delegate implementation. + virtual void OnObjectSignaled(HANDLE object) OVERRIDE; + + int AcceptInternal(scoped_ptr<TCPSocketWin>* socket, + IPEndPoint* address); + + int DoConnect(); + void DoConnectComplete(int result); + + void LogConnectBegin(const AddressList& addresses); + void LogConnectEnd(int net_error); + + int DoRead(IOBuffer* buf, int buf_len, const CompletionCallback& callback); + void DidCompleteConnect(); + void DidCompleteWrite(); + void DidSignalRead(); + + SOCKET socket_; + + HANDLE accept_event_; + base::win::ObjectWatcher accept_watcher_; + + scoped_ptr<TCPSocketWin>* accept_socket_; + IPEndPoint* accept_address_; + CompletionCallback accept_callback_; + + // The various states that the socket could be in. + bool waiting_connect_; + bool waiting_read_; + bool waiting_write_; + + // The core of the socket that can live longer than the socket itself. We pass + // resources to the Windows async IO functions and we have to make sure that + // they are not destroyed while the OS still references them. + scoped_refptr<Core> core_; + + // External callback; called when connect or read is complete. + CompletionCallback read_callback_; + + // External callback; called when write is complete. + CompletionCallback write_callback_; + + scoped_ptr<IPEndPoint> peer_address_; + // The OS error that a connect attempt last completed with. + int connect_os_error_; + + bool logging_multiple_connect_attempts_; + + BoundNetLog net_log_; + + DISALLOW_COPY_AND_ASSIGN(TCPSocketWin); +}; + +} // namespace net + +#endif // NET_SOCKET_TCP_SOCKET_WIN_H_ + diff --git a/chromium/net/socket/transport_client_socket_pool.cc b/chromium/net/socket/transport_client_socket_pool.cc index 8255e988fa4..d03e3e651ac 100644 --- a/chromium/net/socket/transport_client_socket_pool.cc +++ b/chromium/net/socket/transport_client_socket_pool.cc @@ -48,25 +48,18 @@ bool AddressListOnlyContainsIPv6(const AddressList& list) { TransportSocketParams::TransportSocketParams( const HostPortPair& host_port_pair, - RequestPriority priority, bool disable_resolver_cache, bool ignore_limits, const OnHostResolutionCallback& host_resolution_callback) : destination_(host_port_pair), ignore_limits_(ignore_limits), host_resolution_callback_(host_resolution_callback) { - Initialize(priority, disable_resolver_cache); -} - -TransportSocketParams::~TransportSocketParams() {} - -void TransportSocketParams::Initialize(RequestPriority priority, - bool disable_resolver_cache) { - destination_.set_priority(priority); if (disable_resolver_cache) destination_.set_allow_cached_response(false); } +TransportSocketParams::~TransportSocketParams() {} + // TransportConnectJobs will time out after this many seconds. Note this is // the total time, including both host resolution and TCP connect() times. // @@ -80,13 +73,14 @@ static const int kTransportConnectJobTimeoutInSeconds = 240; // 4 minutes. TransportConnectJob::TransportConnectJob( const std::string& group_name, + RequestPriority priority, const scoped_refptr<TransportSocketParams>& params, base::TimeDelta timeout_duration, ClientSocketFactory* client_socket_factory, HostResolver* host_resolver, Delegate* delegate, NetLog* net_log) - : ConnectJob(group_name, timeout_duration, delegate, + : ConnectJob(group_name, timeout_duration, priority, delegate, BoundNetLog::Make(net_log, NetLog::SOURCE_CONNECT_JOB)), params_(params), client_socket_factory_(client_socket_factory), @@ -107,10 +101,11 @@ LoadState TransportConnectJob::GetLoadState() const { case STATE_TRANSPORT_CONNECT: case STATE_TRANSPORT_CONNECT_COMPLETE: return LOAD_STATE_CONNECTING; - default: - NOTREACHED(); + case STATE_NONE: return LOAD_STATE_IDLE; } + NOTREACHED(); + return LOAD_STATE_IDLE; } // static @@ -166,7 +161,9 @@ int TransportConnectJob::DoResolveHost() { connect_timing_.dns_start = base::TimeTicks::Now(); return resolver_.Resolve( - params_->destination(), &addresses_, + params_->destination(), + priority(), + &addresses_, base::Bind(&TransportConnectJob::OnIOComplete, base::Unretained(this)), net_log()); } @@ -190,8 +187,8 @@ int TransportConnectJob::DoResolveHostComplete(int result) { int TransportConnectJob::DoTransportConnect() { next_state_ = STATE_TRANSPORT_CONNECT_COMPLETE; - transport_socket_.reset(client_socket_factory_->CreateTransportClientSocket( - addresses_, net_log().net_log(), net_log().source())); + transport_socket_ = client_socket_factory_->CreateTransportClientSocket( + addresses_, net_log().net_log(), net_log().source()); int rv = transport_socket_->Connect( base::Bind(&TransportConnectJob::OnIOComplete, base::Unretained(this))); if (rv == ERR_IO_PENDING && @@ -246,7 +243,7 @@ int TransportConnectJob::DoTransportConnectComplete(int result) { 100); } } - set_socket(transport_socket_.release()); + SetSocket(transport_socket_.Pass()); fallback_timer_.Stop(); } else { // Be a bit paranoid and kill off the fallback members to prevent reuse. @@ -270,9 +267,9 @@ void TransportConnectJob::DoIPv6FallbackTransportConnect() { fallback_addresses_.reset(new AddressList(addresses_)); MakeAddressListStartWithIPv4(fallback_addresses_.get()); - fallback_transport_socket_.reset( + fallback_transport_socket_ = client_socket_factory_->CreateTransportClientSocket( - *fallback_addresses_, net_log().net_log(), net_log().source())); + *fallback_addresses_, net_log().net_log(), net_log().source()); fallback_connect_start_time_ = base::TimeTicks::Now(); int rv = fallback_transport_socket_->Connect( base::Bind( @@ -317,7 +314,7 @@ void TransportConnectJob::DoIPv6FallbackTransportConnectComplete(int result) { base::TimeDelta::FromMilliseconds(1), base::TimeDelta::FromMinutes(10), 100); - set_socket(fallback_transport_socket_.release()); + SetSocket(fallback_transport_socket_.Pass()); next_state_ = STATE_NONE; transport_socket_.reset(); } else { @@ -333,18 +330,20 @@ int TransportConnectJob::ConnectInternal() { return DoLoop(OK); } -ConnectJob* +scoped_ptr<ConnectJob> TransportClientSocketPool::TransportConnectJobFactory::NewConnectJob( const std::string& group_name, const PoolBase::Request& request, ConnectJob::Delegate* delegate) const { - return new TransportConnectJob(group_name, - request.params(), - ConnectionTimeout(), - client_socket_factory_, - host_resolver_, - delegate, - net_log_); + return scoped_ptr<ConnectJob>( + new TransportConnectJob(group_name, + request.priority(), + request.params(), + ConnectionTimeout(), + client_socket_factory_, + host_resolver_, + delegate, + net_log_)); } base::TimeDelta @@ -360,11 +359,11 @@ TransportClientSocketPool::TransportClientSocketPool( HostResolver* host_resolver, ClientSocketFactory* client_socket_factory, NetLog* net_log) - : base_(max_sockets, max_sockets_per_group, histograms, + : base_(NULL, max_sockets, max_sockets_per_group, histograms, ClientSocketPool::unused_idle_socket_timeout(), ClientSocketPool::used_idle_socket_timeout(), new TransportConnectJobFactory(client_socket_factory, - host_resolver, net_log)) { + host_resolver, net_log)) { base_.EnableConnectBackupJobs(); } @@ -419,19 +418,15 @@ void TransportClientSocketPool::CancelRequest( void TransportClientSocketPool::ReleaseSocket( const std::string& group_name, - StreamSocket* socket, + scoped_ptr<StreamSocket> socket, int id) { - base_.ReleaseSocket(group_name, socket, id); + base_.ReleaseSocket(group_name, socket.Pass(), id); } void TransportClientSocketPool::FlushWithError(int error) { base_.FlushWithError(error); } -bool TransportClientSocketPool::IsStalled() const { - return base_.IsStalled(); -} - void TransportClientSocketPool::CloseIdleSockets() { base_.CloseIdleSockets(); } @@ -450,14 +445,6 @@ LoadState TransportClientSocketPool::GetLoadState( return base_.GetLoadState(group_name, handle); } -void TransportClientSocketPool::AddLayeredPool(LayeredPool* layered_pool) { - base_.AddLayeredPool(layered_pool); -} - -void TransportClientSocketPool::RemoveLayeredPool(LayeredPool* layered_pool) { - base_.RemoveLayeredPool(layered_pool); -} - base::DictionaryValue* TransportClientSocketPool::GetInfoAsValue( const std::string& name, const std::string& type, @@ -473,4 +460,18 @@ ClientSocketPoolHistograms* TransportClientSocketPool::histograms() const { return base_.histograms(); } +bool TransportClientSocketPool::IsStalled() const { + return base_.IsStalled(); +} + +void TransportClientSocketPool::AddHigherLayeredPool( + HigherLayeredPool* higher_pool) { + base_.AddHigherLayeredPool(higher_pool); +} + +void TransportClientSocketPool::RemoveHigherLayeredPool( + HigherLayeredPool* higher_pool) { + base_.RemoveHigherLayeredPool(higher_pool); +} + } // namespace net diff --git a/chromium/net/socket/transport_client_socket_pool.h b/chromium/net/socket/transport_client_socket_pool.h index bb53b3da301..16e421a4550 100644 --- a/chromium/net/socket/transport_client_socket_pool.h +++ b/chromium/net/socket/transport_client_socket_pool.h @@ -34,7 +34,6 @@ class NET_EXPORT_PRIVATE TransportSocketParams // connection will be aborted with that value. TransportSocketParams( const HostPortPair& host_port_pair, - RequestPriority priority, bool disable_resolver_cache, bool ignore_limits, const OnHostResolutionCallback& host_resolution_callback); @@ -49,8 +48,6 @@ class NET_EXPORT_PRIVATE TransportSocketParams friend class base::RefCounted<TransportSocketParams>; ~TransportSocketParams(); - void Initialize(RequestPriority priority, bool disable_resolver_cache); - HostResolver::RequestInfo destination_; bool ignore_limits_; const OnHostResolutionCallback host_resolution_callback_; @@ -69,6 +66,7 @@ class NET_EXPORT_PRIVATE TransportSocketParams class NET_EXPORT_PRIVATE TransportConnectJob : public ConnectJob { public: TransportConnectJob(const std::string& group_name, + RequestPriority priority, const scoped_refptr<TransportSocketParams>& params, base::TimeDelta timeout_duration, ClientSocketFactory* client_socket_factory, @@ -132,6 +130,8 @@ class NET_EXPORT_PRIVATE TransportConnectJob : public ConnectJob { class NET_EXPORT_PRIVATE TransportClientSocketPool : public ClientSocketPool { public: + typedef TransportSocketParams SocketParams; + TransportClientSocketPool( int max_sockets, int max_sockets_per_group, @@ -156,10 +156,9 @@ class NET_EXPORT_PRIVATE TransportClientSocketPool : public ClientSocketPool { virtual void CancelRequest(const std::string& group_name, ClientSocketHandle* handle) OVERRIDE; virtual void ReleaseSocket(const std::string& group_name, - StreamSocket* socket, + scoped_ptr<StreamSocket> socket, int id) OVERRIDE; virtual void FlushWithError(int error) OVERRIDE; - virtual bool IsStalled() const OVERRIDE; virtual void CloseIdleSockets() OVERRIDE; virtual int IdleSocketCount() const OVERRIDE; virtual int IdleSocketCountInGroup( @@ -167,8 +166,6 @@ class NET_EXPORT_PRIVATE TransportClientSocketPool : public ClientSocketPool { virtual LoadState GetLoadState( const std::string& group_name, const ClientSocketHandle* handle) const OVERRIDE; - virtual void AddLayeredPool(LayeredPool* layered_pool) OVERRIDE; - virtual void RemoveLayeredPool(LayeredPool* layered_pool) OVERRIDE; virtual base::DictionaryValue* GetInfoAsValue( const std::string& name, const std::string& type, @@ -176,6 +173,11 @@ class NET_EXPORT_PRIVATE TransportClientSocketPool : public ClientSocketPool { virtual base::TimeDelta ConnectionTimeout() const OVERRIDE; virtual ClientSocketPoolHistograms* histograms() const OVERRIDE; + // HigherLayeredPool implementation. + virtual bool IsStalled() const OVERRIDE; + virtual void AddHigherLayeredPool(HigherLayeredPool* higher_pool) OVERRIDE; + virtual void RemoveHigherLayeredPool(HigherLayeredPool* higher_pool) OVERRIDE; + private: typedef ClientSocketPoolBase<TransportSocketParams> PoolBase; @@ -193,7 +195,7 @@ class NET_EXPORT_PRIVATE TransportClientSocketPool : public ClientSocketPool { // ClientSocketPoolBase::ConnectJobFactory methods. - virtual ConnectJob* NewConnectJob( + virtual scoped_ptr<ConnectJob> NewConnectJob( const std::string& group_name, const PoolBase::Request& request, ConnectJob::Delegate* delegate) const OVERRIDE; @@ -213,9 +215,6 @@ class NET_EXPORT_PRIVATE TransportClientSocketPool : public ClientSocketPool { DISALLOW_COPY_AND_ASSIGN(TransportClientSocketPool); }; -REGISTER_SOCKET_PARAMS_FOR_POOL(TransportClientSocketPool, - TransportSocketParams); - } // namespace net #endif // NET_SOCKET_TRANSPORT_CLIENT_SOCKET_POOL_H_ diff --git a/chromium/net/socket/transport_client_socket_pool_unittest.cc b/chromium/net/socket/transport_client_socket_pool_unittest.cc index dfa1151b291..a984ea3b740 100644 --- a/chromium/net/socket/transport_client_socket_pool_unittest.cc +++ b/chromium/net/socket/transport_client_socket_pool_unittest.cc @@ -23,6 +23,7 @@ #include "net/socket/client_socket_handle.h" #include "net/socket/client_socket_pool_histograms.h" #include "net/socket/socket_test_util.h" +#include "net/socket/ssl_client_socket.h" #include "net/socket/stream_socket.h" #include "testing/gtest/include/gtest/gtest.h" @@ -340,16 +341,16 @@ class MockClientSocketFactory : public ClientSocketFactory { delay_(base::TimeDelta::FromMilliseconds( ClientSocketPool::kMaxConnectRetryIntervalMs)) {} - virtual DatagramClientSocket* CreateDatagramClientSocket( + virtual scoped_ptr<DatagramClientSocket> CreateDatagramClientSocket( DatagramSocket::BindType bind_type, const RandIntCallback& rand_int_cb, NetLog* net_log, const NetLog::Source& source) OVERRIDE { NOTREACHED(); - return NULL; + return scoped_ptr<DatagramClientSocket>(); } - virtual StreamSocket* CreateTransportClientSocket( + virtual scoped_ptr<StreamSocket> CreateTransportClientSocket( const AddressList& addresses, NetLog* /* net_log */, const NetLog::Source& /* source */) OVERRIDE { @@ -363,34 +364,41 @@ class MockClientSocketFactory : public ClientSocketFactory { switch (type) { case MOCK_CLIENT_SOCKET: - return new MockClientSocket(addresses, net_log_); + return scoped_ptr<StreamSocket>( + new MockClientSocket(addresses, net_log_)); case MOCK_FAILING_CLIENT_SOCKET: - return new MockFailingClientSocket(addresses, net_log_); + return scoped_ptr<StreamSocket>( + new MockFailingClientSocket(addresses, net_log_)); case MOCK_PENDING_CLIENT_SOCKET: - return new MockPendingClientSocket( - addresses, true, false, base::TimeDelta(), net_log_); + return scoped_ptr<StreamSocket>( + new MockPendingClientSocket( + addresses, true, false, base::TimeDelta(), net_log_)); case MOCK_PENDING_FAILING_CLIENT_SOCKET: - return new MockPendingClientSocket( - addresses, false, false, base::TimeDelta(), net_log_); + return scoped_ptr<StreamSocket>( + new MockPendingClientSocket( + addresses, false, false, base::TimeDelta(), net_log_)); case MOCK_DELAYED_CLIENT_SOCKET: - return new MockPendingClientSocket( - addresses, true, false, delay_, net_log_); + return scoped_ptr<StreamSocket>( + new MockPendingClientSocket( + addresses, true, false, delay_, net_log_)); case MOCK_STALLED_CLIENT_SOCKET: - return new MockPendingClientSocket( - addresses, true, true, base::TimeDelta(), net_log_); + return scoped_ptr<StreamSocket>( + new MockPendingClientSocket( + addresses, true, true, base::TimeDelta(), net_log_)); default: NOTREACHED(); - return new MockClientSocket(addresses, net_log_); + return scoped_ptr<StreamSocket>( + new MockClientSocket(addresses, net_log_)); } } - virtual SSLClientSocket* CreateSSLClientSocket( - ClientSocketHandle* transport_socket, + virtual scoped_ptr<SSLClientSocket> CreateSSLClientSocket( + scoped_ptr<ClientSocketHandle> transport_socket, const HostPortPair& host_and_port, const SSLConfig& ssl_config, const SSLClientSocketContext& context) OVERRIDE { NOTIMPLEMENTED(); - return NULL; + return scoped_ptr<SSLClientSocket>(); } virtual void ClearSSLSessionCache() OVERRIDE { @@ -431,11 +439,7 @@ class TransportClientSocketPoolTest : public testing::Test { ClientSocketPoolBaseHelper::set_connect_backup_jobs_enabled(true)), params_( new TransportSocketParams(HostPortPair("www.google.com", 80), - kDefaultPriority, false, false, - OnHostResolutionCallback())), - low_params_( - new TransportSocketParams(HostPortPair("www.google.com", 80), - LOW, false, false, + false, false, OnHostResolutionCallback())), histograms_(new ClientSocketPoolHistograms("TCPUnitTest")), host_resolver_(new MockHostResolver), @@ -455,7 +459,7 @@ class TransportClientSocketPoolTest : public testing::Test { int StartRequest(const std::string& group_name, RequestPriority priority) { scoped_refptr<TransportSocketParams> params(new TransportSocketParams( - HostPortPair("www.google.com", 80), MEDIUM, false, false, + HostPortPair("www.google.com", 80), false, false, OnHostResolutionCallback())); return test_base_.StartRequestUsingPool( &pool_, group_name, priority, params); @@ -479,7 +483,6 @@ class TransportClientSocketPoolTest : public testing::Test { bool connect_backup_jobs_enabled_; CapturingNetLog net_log_; scoped_refptr<TransportSocketParams> params_; - scoped_refptr<TransportSocketParams> low_params_; scoped_ptr<ClientSocketPoolHistograms> histograms_; scoped_ptr<MockHostResolver> host_resolver_; MockClientSocketFactory client_socket_factory_; @@ -561,7 +564,7 @@ TEST(TransportConnectJobTest, MakeAddrListStartWithIPv4) { TEST_F(TransportClientSocketPoolTest, Basic) { TestCompletionCallback callback; ClientSocketHandle handle; - int rv = handle.Init("a", low_params_, LOW, callback.callback(), &pool_, + int rv = handle.Init("a", params_, LOW, callback.callback(), &pool_, BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle.is_initialized()); @@ -573,13 +576,27 @@ TEST_F(TransportClientSocketPoolTest, Basic) { TestLoadTimingInfoConnectedNotReused(handle); } +// Make sure that TransportConnectJob passes on its priority to its +// HostResolver request on Init. +TEST_F(TransportClientSocketPoolTest, SetResolvePriorityOnInit) { + for (int i = MINIMUM_PRIORITY; i < NUM_PRIORITIES; ++i) { + RequestPriority priority = static_cast<RequestPriority>(i); + TestCompletionCallback callback; + ClientSocketHandle handle; + EXPECT_EQ(ERR_IO_PENDING, + handle.Init("a", params_, priority, callback.callback(), &pool_, + BoundNetLog())); + EXPECT_EQ(priority, host_resolver_->last_request_priority()); + } +} + TEST_F(TransportClientSocketPoolTest, InitHostResolutionFailure) { host_resolver_->rules()->AddSimulatedFailure("unresolvable.host.name"); TestCompletionCallback callback; ClientSocketHandle handle; HostPortPair host_port_pair("unresolvable.host.name", 80); scoped_refptr<TransportSocketParams> dest(new TransportSocketParams( - host_port_pair, kDefaultPriority, false, false, + host_port_pair, false, false, OnHostResolutionCallback())); EXPECT_EQ(ERR_IO_PENDING, handle.Init("a", dest, kDefaultPriority, callback.callback(), @@ -854,7 +871,7 @@ class RequestSocketCallback : public TestCompletionCallbackBase { } within_callback_ = true; scoped_refptr<TransportSocketParams> dest(new TransportSocketParams( - HostPortPair("www.google.com", 80), LOWEST, false, false, + HostPortPair("www.google.com", 80), false, false, OnHostResolutionCallback())); int rv = handle_->Init("a", dest, LOWEST, callback(), pool_, BoundNetLog()); @@ -874,7 +891,7 @@ TEST_F(TransportClientSocketPoolTest, RequestTwice) { ClientSocketHandle handle; RequestSocketCallback callback(&handle, &pool_); scoped_refptr<TransportSocketParams> dest(new TransportSocketParams( - HostPortPair("www.google.com", 80), LOWEST, false, false, + HostPortPair("www.google.com", 80), false, false, OnHostResolutionCallback())); int rv = handle.Init("a", dest, LOWEST, callback.callback(), &pool_, BoundNetLog()); @@ -939,7 +956,7 @@ TEST_F(TransportClientSocketPoolTest, FailingActiveRequestWithPendingRequests) { TEST_F(TransportClientSocketPoolTest, IdleSocketLoadTiming) { TestCompletionCallback callback; ClientSocketHandle handle; - int rv = handle.Init("a", low_params_, LOW, callback.callback(), &pool_, + int rv = handle.Init("a", params_, LOW, callback.callback(), &pool_, BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle.is_initialized()); @@ -957,7 +974,7 @@ TEST_F(TransportClientSocketPoolTest, IdleSocketLoadTiming) { // Now we should have 1 idle socket. EXPECT_EQ(1, pool_.IdleSocketCount()); - rv = handle.Init("a", low_params_, LOW, callback.callback(), &pool_, + rv = handle.Init("a", params_, LOW, callback.callback(), &pool_, BoundNetLog()); EXPECT_EQ(OK, rv); EXPECT_EQ(0, pool_.IdleSocketCount()); @@ -967,7 +984,7 @@ TEST_F(TransportClientSocketPoolTest, IdleSocketLoadTiming) { TEST_F(TransportClientSocketPoolTest, ResetIdleSocketsOnIPAddressChange) { TestCompletionCallback callback; ClientSocketHandle handle; - int rv = handle.Init("a", low_params_, LOW, callback.callback(), &pool_, + int rv = handle.Init("a", params_, LOW, callback.callback(), &pool_, BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle.is_initialized()); @@ -1023,7 +1040,7 @@ TEST_F(TransportClientSocketPoolTest, BackupSocketConnect) { TestCompletionCallback callback; ClientSocketHandle handle; - int rv = handle.Init("b", low_params_, LOW, callback.callback(), &pool_, + int rv = handle.Init("b", params_, LOW, callback.callback(), &pool_, BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle.is_initialized()); @@ -1065,7 +1082,7 @@ TEST_F(TransportClientSocketPoolTest, BackupSocketCancel) { TestCompletionCallback callback; ClientSocketHandle handle; - int rv = handle.Init("c", low_params_, LOW, callback.callback(), &pool_, + int rv = handle.Init("c", params_, LOW, callback.callback(), &pool_, BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle.is_initialized()); @@ -1111,7 +1128,7 @@ TEST_F(TransportClientSocketPoolTest, BackupSocketFailAfterStall) { TestCompletionCallback callback; ClientSocketHandle handle; - int rv = handle.Init("b", low_params_, LOW, callback.callback(), &pool_, + int rv = handle.Init("b", params_, LOW, callback.callback(), &pool_, BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle.is_initialized()); @@ -1159,7 +1176,7 @@ TEST_F(TransportClientSocketPoolTest, BackupSocketFailAfterDelay) { TestCompletionCallback callback; ClientSocketHandle handle; - int rv = handle.Init("b", low_params_, LOW, callback.callback(), &pool_, + int rv = handle.Init("b", params_, LOW, callback.callback(), &pool_, BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle.is_initialized()); @@ -1215,7 +1232,7 @@ TEST_F(TransportClientSocketPoolTest, IPv6FallbackSocketIPv4FinishesFirst) { TestCompletionCallback callback; ClientSocketHandle handle; - int rv = handle.Init("a", low_params_, LOW, callback.callback(), &pool, + int rv = handle.Init("a", params_, LOW, callback.callback(), &pool, BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle.is_initialized()); @@ -1260,7 +1277,7 @@ TEST_F(TransportClientSocketPoolTest, IPv6FallbackSocketIPv6FinishesFirst) { TestCompletionCallback callback; ClientSocketHandle handle; - int rv = handle.Init("a", low_params_, LOW, callback.callback(), &pool, + int rv = handle.Init("a", params_, LOW, callback.callback(), &pool, BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle.is_initialized()); @@ -1294,7 +1311,7 @@ TEST_F(TransportClientSocketPoolTest, IPv6NoIPv4AddressesToFallbackTo) { TestCompletionCallback callback; ClientSocketHandle handle; - int rv = handle.Init("a", low_params_, LOW, callback.callback(), &pool, + int rv = handle.Init("a", params_, LOW, callback.callback(), &pool, BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle.is_initialized()); @@ -1327,7 +1344,7 @@ TEST_F(TransportClientSocketPoolTest, IPv4HasNoFallback) { TestCompletionCallback callback; ClientSocketHandle handle; - int rv = handle.Init("a", low_params_, LOW, callback.callback(), &pool, + int rv = handle.Init("a", params_, LOW, callback.callback(), &pool, BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle.is_initialized()); diff --git a/chromium/net/socket/transport_client_socket_unittest.cc b/chromium/net/socket/transport_client_socket_unittest.cc index 2f75e740067..5548b27b995 100644 --- a/chromium/net/socket/transport_client_socket_unittest.cc +++ b/chromium/net/socket/transport_client_socket_unittest.cc @@ -48,8 +48,9 @@ class TransportClientSocketTest // Implement StreamListenSocket::Delegate methods virtual void DidAccept(StreamListenSocket* server, - StreamListenSocket* connection) OVERRIDE { - connected_sock_ = reinterpret_cast<TCPListenSocket*>(connection); + scoped_ptr<StreamListenSocket> connection) OVERRIDE { + connected_sock_.reset( + static_cast<TCPListenSocket*>(connection.release())); } virtual void DidRead(StreamListenSocket*, const char* str, int len) OVERRIDE { // TODO(dkegel): this might not be long enough to tickle some bugs. @@ -65,7 +66,7 @@ class TransportClientSocketTest void CloseServerSocket() { // delete the connected_sock_, which will close it. - connected_sock_ = NULL; + connected_sock_.reset(); } void PauseServerReads() { @@ -94,8 +95,8 @@ class TransportClientSocketTest scoped_ptr<StreamSocket> sock_; private: - scoped_refptr<TCPListenSocket> listen_sock_; - scoped_refptr<TCPListenSocket> connected_sock_; + scoped_ptr<TCPListenSocket> listen_sock_; + scoped_ptr<TCPListenSocket> connected_sock_; bool close_server_socket_on_next_send_; }; @@ -103,7 +104,7 @@ void TransportClientSocketTest::SetUp() { ::testing::TestWithParam<ClientSocketTestTypes>::SetUp(); // Find a free port to listen on - scoped_refptr<TCPListenSocket> sock; + scoped_ptr<TCPListenSocket> sock; int port; // Range of ports to listen on. Shouldn't need to try many. const int kMinPort = 10100; @@ -117,7 +118,7 @@ void TransportClientSocketTest::SetUp() { break; } ASSERT_TRUE(sock.get() != NULL); - listen_sock_ = sock; + listen_sock_ = sock.Pass(); listen_port_ = port; AddressList addr; @@ -125,15 +126,15 @@ void TransportClientSocketTest::SetUp() { scoped_ptr<HostResolver> resolver(new MockHostResolver()); HostResolver::RequestInfo info(HostPortPair("localhost", listen_port_)); TestCompletionCallback callback; - int rv = resolver->Resolve(info, &addr, callback.callback(), NULL, - BoundNetLog()); + int rv = resolver->Resolve( + info, DEFAULT_PRIORITY, &addr, callback.callback(), NULL, BoundNetLog()); CHECK_EQ(ERR_IO_PENDING, rv); rv = callback.WaitForResult(); CHECK_EQ(rv, OK); - sock_.reset( + sock_ = socket_factory_->CreateTransportClientSocket(addr, &net_log_, - NetLog::Source())); + NetLog::Source()); } int TransportClientSocketTest::DrainClientSocket( diff --git a/chromium/net/socket/unix_domain_socket_posix.cc b/chromium/net/socket/unix_domain_socket_posix.cc index 5b6b2498245..2b781d58b35 100644 --- a/chromium/net/socket/unix_domain_socket_posix.cc +++ b/chromium/net/socket/unix_domain_socket_posix.cc @@ -21,6 +21,7 @@ #include "build/build_config.h" #include "net/base/net_errors.h" #include "net/base/net_util.h" +#include "net/socket/socket_descriptor.h" namespace net { @@ -48,12 +49,12 @@ bool GetPeerIds(int socket, uid_t* user_id, gid_t* group_id) { } // namespace // static -UnixDomainSocket::AuthCallback NoAuthentication() { +UnixDomainSocket::AuthCallback UnixDomainSocket::NoAuthentication() { return base::Bind(NoAuthenticationCallback); } // static -UnixDomainSocket* UnixDomainSocket::CreateAndListenInternal( +scoped_ptr<UnixDomainSocket> UnixDomainSocket::CreateAndListenInternal( const std::string& path, const std::string& fallback_path, StreamListenSocket::Delegate* del, @@ -63,14 +64,15 @@ UnixDomainSocket* UnixDomainSocket::CreateAndListenInternal( if (s == kInvalidSocket && !fallback_path.empty()) s = CreateAndBind(fallback_path, use_abstract_namespace); if (s == kInvalidSocket) - return NULL; - UnixDomainSocket* sock = new UnixDomainSocket(s, del, auth_callback); + return scoped_ptr<UnixDomainSocket>(); + scoped_ptr<UnixDomainSocket> sock( + new UnixDomainSocket(s, del, auth_callback)); sock->Listen(); - return sock; + return sock.Pass(); } // static -scoped_refptr<UnixDomainSocket> UnixDomainSocket::CreateAndListen( +scoped_ptr<UnixDomainSocket> UnixDomainSocket::CreateAndListen( const std::string& path, StreamListenSocket::Delegate* del, const AuthCallback& auth_callback) { @@ -79,14 +81,14 @@ scoped_refptr<UnixDomainSocket> UnixDomainSocket::CreateAndListen( #if defined(SOCKET_ABSTRACT_NAMESPACE_SUPPORTED) // static -scoped_refptr<UnixDomainSocket> +scoped_ptr<UnixDomainSocket> UnixDomainSocket::CreateAndListenWithAbstractNamespace( const std::string& path, const std::string& fallback_path, StreamListenSocket::Delegate* del, const AuthCallback& auth_callback) { - return make_scoped_refptr( - CreateAndListenInternal(path, fallback_path, del, auth_callback, true)); + return + CreateAndListenInternal(path, fallback_path, del, auth_callback, true); } #endif @@ -106,7 +108,7 @@ SocketDescriptor UnixDomainSocket::CreateAndBind(const std::string& path, static const size_t kPathMax = sizeof(addr.sun_path); if (use_abstract_namespace + path.size() + 1 /* '\0' */ > kPathMax) return kInvalidSocket; - const SocketDescriptor s = socket(PF_UNIX, SOCK_STREAM, 0); + const SocketDescriptor s = CreatePlatformSocket(PF_UNIX, SOCK_STREAM, 0); if (s == kInvalidSocket) return kInvalidSocket; memset(&addr, 0, sizeof(addr)); @@ -147,11 +149,11 @@ void UnixDomainSocket::Accept() { LOG(ERROR) << "close() error"; return; } - scoped_refptr<UnixDomainSocket> sock( + scoped_ptr<UnixDomainSocket> sock( new UnixDomainSocket(conn, socket_delegate_, auth_callback_)); // It's up to the delegate to AddRef if it wants to keep it around. sock->WatchSocket(WAITING_READ); - socket_delegate_->DidAccept(this, sock.get()); + socket_delegate_->DidAccept(this, sock.PassAs<StreamListenSocket>()); } UnixDomainSocketFactory::UnixDomainSocketFactory( @@ -162,10 +164,10 @@ UnixDomainSocketFactory::UnixDomainSocketFactory( UnixDomainSocketFactory::~UnixDomainSocketFactory() {} -scoped_refptr<StreamListenSocket> UnixDomainSocketFactory::CreateAndListen( +scoped_ptr<StreamListenSocket> UnixDomainSocketFactory::CreateAndListen( StreamListenSocket::Delegate* delegate) const { return UnixDomainSocket::CreateAndListen( - path_, delegate, auth_callback_); + path_, delegate, auth_callback_).PassAs<StreamListenSocket>(); } #if defined(SOCKET_ABSTRACT_NAMESPACE_SUPPORTED) @@ -181,11 +183,12 @@ UnixDomainSocketWithAbstractNamespaceFactory( UnixDomainSocketWithAbstractNamespaceFactory:: ~UnixDomainSocketWithAbstractNamespaceFactory() {} -scoped_refptr<StreamListenSocket> +scoped_ptr<StreamListenSocket> UnixDomainSocketWithAbstractNamespaceFactory::CreateAndListen( StreamListenSocket::Delegate* delegate) const { return UnixDomainSocket::CreateAndListenWithAbstractNamespace( - path_, fallback_path_, delegate, auth_callback_); + path_, fallback_path_, delegate, auth_callback_) + .PassAs<StreamListenSocket>(); } #endif diff --git a/chromium/net/socket/unix_domain_socket_posix.h b/chromium/net/socket/unix_domain_socket_posix.h index 2ef06803d24..98d0c11a648 100644 --- a/chromium/net/socket/unix_domain_socket_posix.h +++ b/chromium/net/socket/unix_domain_socket_posix.h @@ -10,7 +10,6 @@ #include "base/basictypes.h" #include "base/callback_forward.h" #include "base/compiler_specific.h" -#include "base/memory/ref_counted.h" #include "build/build_config.h" #include "net/base/net_export.h" #include "net/socket/stream_listen_socket.h" @@ -26,6 +25,8 @@ namespace net { // Unix Domain Socket Implementation. Supports abstract namespaces on Linux. class NET_EXPORT UnixDomainSocket : public StreamListenSocket { public: + virtual ~UnixDomainSocket(); + // Callback that returns whether the already connected client, identified by // its process |user_id| and |group_id|, is allowed to keep the connection // open. Note that the socket is closed immediately in case the callback @@ -38,7 +39,7 @@ class NET_EXPORT UnixDomainSocket : public StreamListenSocket { // Note that the returned UnixDomainSocket instance does not take ownership of // |del|. - static scoped_refptr<UnixDomainSocket> CreateAndListen( + static scoped_ptr<UnixDomainSocket> CreateAndListen( const std::string& path, StreamListenSocket::Delegate* del, const AuthCallback& auth_callback); @@ -47,7 +48,7 @@ class NET_EXPORT UnixDomainSocket : public StreamListenSocket { // Same as above except that the created socket uses the abstract namespace // which is a Linux-only feature. If |fallback_path| is not empty, // make the second attempt with the provided fallback name. - static scoped_refptr<UnixDomainSocket> CreateAndListenWithAbstractNamespace( + static scoped_ptr<UnixDomainSocket> CreateAndListenWithAbstractNamespace( const std::string& path, const std::string& fallback_path, StreamListenSocket::Delegate* del, @@ -58,9 +59,8 @@ class NET_EXPORT UnixDomainSocket : public StreamListenSocket { UnixDomainSocket(SocketDescriptor s, StreamListenSocket::Delegate* del, const AuthCallback& auth_callback); - virtual ~UnixDomainSocket(); - static UnixDomainSocket* CreateAndListenInternal( + static scoped_ptr<UnixDomainSocket> CreateAndListenInternal( const std::string& path, const std::string& fallback_path, StreamListenSocket::Delegate* del, @@ -87,7 +87,7 @@ class NET_EXPORT UnixDomainSocketFactory : public StreamListenSocketFactory { virtual ~UnixDomainSocketFactory(); // StreamListenSocketFactory: - virtual scoped_refptr<StreamListenSocket> CreateAndListen( + virtual scoped_ptr<StreamListenSocket> CreateAndListen( StreamListenSocket::Delegate* delegate) const OVERRIDE; protected: @@ -111,7 +111,7 @@ class NET_EXPORT UnixDomainSocketWithAbstractNamespaceFactory virtual ~UnixDomainSocketWithAbstractNamespaceFactory(); // UnixDomainSocketFactory: - virtual scoped_refptr<StreamListenSocket> CreateAndListen( + virtual scoped_ptr<StreamListenSocket> CreateAndListen( StreamListenSocket::Delegate* delegate) const OVERRIDE; private: diff --git a/chromium/net/socket/unix_domain_socket_posix_unittest.cc b/chromium/net/socket/unix_domain_socket_posix_unittest.cc index 5abe03b4ae3..f062d274205 100644 --- a/chromium/net/socket/unix_domain_socket_posix_unittest.cc +++ b/chromium/net/socket/unix_domain_socket_posix_unittest.cc @@ -29,6 +29,7 @@ #include "base/synchronization/lock.h" #include "base/threading/platform_thread.h" #include "base/threading/thread.h" +#include "net/socket/socket_descriptor.h" #include "net/socket/unix_domain_socket_posix.h" #include "testing/gtest/include/gtest/gtest.h" @@ -102,9 +103,9 @@ class TestListenSocketDelegate : public StreamListenSocket::Delegate { : event_manager_(event_manager) {} virtual void DidAccept(StreamListenSocket* server, - StreamListenSocket* connection) OVERRIDE { + scoped_ptr<StreamListenSocket> connection) OVERRIDE { LOG(ERROR) << __PRETTY_FUNCTION__; - connection_ = connection; + connection_ = connection.Pass(); Notify(EVENT_ACCEPT); } @@ -138,7 +139,7 @@ class TestListenSocketDelegate : public StreamListenSocket::Delegate { } const scoped_refptr<EventManager> event_manager_; - scoped_refptr<StreamListenSocket> connection_; + scoped_ptr<StreamListenSocket> connection_; base::Lock mutex_; string data_; }; @@ -172,7 +173,7 @@ class UnixDomainSocketTestHelper : public testing::Test { virtual void TearDown() OVERRIDE { DeleteSocketFile(); - socket_ = NULL; + socket_.reset(); socket_delegate_.reset(); event_manager_ = NULL; } @@ -187,10 +188,10 @@ class UnixDomainSocketTestHelper : public testing::Test { } SocketDescriptor CreateClientSocket() { - const SocketDescriptor sock = socket(PF_UNIX, SOCK_STREAM, 0); + const SocketDescriptor sock = CreatePlatformSocket(PF_UNIX, SOCK_STREAM, 0); if (sock < 0) { LOG(ERROR) << "socket() error"; - return StreamListenSocket::kInvalidSocket; + return kInvalidSocket; } sockaddr_un addr; memset(&addr, 0, sizeof(addr)); @@ -200,7 +201,7 @@ class UnixDomainSocketTestHelper : public testing::Test { addr_len = sizeof(sockaddr_un); if (connect(sock, reinterpret_cast<sockaddr*>(&addr), addr_len) != 0) { LOG(ERROR) << "connect() error"; - return StreamListenSocket::kInvalidSocket; + return kInvalidSocket; } return sock; } @@ -221,7 +222,7 @@ class UnixDomainSocketTestHelper : public testing::Test { const bool allow_user_; scoped_refptr<EventManager> event_manager_; scoped_ptr<TestListenSocketDelegate> socket_delegate_; - scoped_refptr<UnixDomainSocket> socket_; + scoped_ptr<UnixDomainSocket> socket_; }; class UnixDomainSocketTest : public UnixDomainSocketTestHelper { @@ -264,7 +265,7 @@ TEST_F(UnixDomainSocketTestWithInvalidPath, } TEST_F(UnixDomainSocketTest, TestFallbackName) { - scoped_refptr<UnixDomainSocket> existing_socket = + scoped_ptr<UnixDomainSocket> existing_socket = UnixDomainSocket::CreateAndListenWithAbstractNamespace( file_path_.value(), "", socket_delegate_.get(), MakeAuthCallback()); EXPECT_FALSE(existing_socket.get() == NULL); @@ -280,7 +281,6 @@ TEST_F(UnixDomainSocketTest, TestFallbackName) { socket_delegate_.get(), MakeAuthCallback()); EXPECT_FALSE(socket_.get() == NULL); - existing_socket = NULL; } #endif @@ -291,7 +291,7 @@ TEST_F(UnixDomainSocketTest, TestWithClient) { // Create the client socket. const SocketDescriptor sock = CreateClientSocket(); - ASSERT_NE(StreamListenSocket::kInvalidSocket, sock); + ASSERT_NE(kInvalidSocket, sock); event = event_manager_->WaitForEvent(); ASSERT_EQ(EVENT_AUTH_GRANTED, event); event = event_manager_->WaitForEvent(); @@ -316,7 +316,7 @@ TEST_F(UnixDomainSocketTestWithForbiddenUser, TestWithForbiddenUser) { EventType event = event_manager_->WaitForEvent(); ASSERT_EQ(EVENT_LISTEN, event); const SocketDescriptor sock = CreateClientSocket(); - ASSERT_NE(StreamListenSocket::kInvalidSocket, sock); + ASSERT_NE(kInvalidSocket, sock); event = event_manager_->WaitForEvent(); ASSERT_EQ(EVENT_AUTH_DENIED, event); diff --git a/chromium/net/socket_stream/socket_stream.cc b/chromium/net/socket_stream/socket_stream.cc index dca994cbada..699c8204aea 100644 --- a/chromium/net/socket_stream/socket_stream.cc +++ b/chromium/net/socket_stream/socket_stream.cc @@ -33,6 +33,7 @@ #include "net/http/http_transaction_factory.h" #include "net/http/http_util.h" #include "net/socket/client_socket_factory.h" +#include "net/socket/client_socket_handle.h" #include "net/socket/socks5_client_socket.h" #include "net/socket/socks_client_socket.h" #include "net/socket/ssl_client_socket.h" @@ -96,6 +97,7 @@ SocketStream::SocketStream(const GURL& url, Delegate* delegate) proxy_mode_(kDirectConnection), proxy_url_(url), pac_request_(NULL), + connection_(new ClientSocketHandle), privacy_mode_(kPrivacyModeDisabled), // Unretained() is required; without it, Bind() creates a circular // dependency and the SocketStream object will not be freed. @@ -205,8 +207,10 @@ bool SocketStream::SendData(const char* data, int len) { << "The current base::MessageLoop must be TYPE_IO"; DCHECK_GT(len, 0); - if (!socket_.get() || !socket_->IsConnected() || next_state_ == STATE_NONE) + if (!connection_->socket() || + !connection_->socket()->IsConnected() || next_state_ == STATE_NONE) { return false; + } int total_buffered_bytes = len; if (current_write_buf_.get()) { @@ -264,7 +268,7 @@ void SocketStream::RestartWithAuth(const AuthCredentials& credentials) { DCHECK_EQ(base::MessageLoop::TYPE_IO, base::MessageLoop::current()->type()) << "The current base::MessageLoop must be TYPE_IO"; DCHECK(proxy_auth_controller_.get()); - if (!socket_.get()) { + if (!connection_->socket()) { DVLOG(1) << "Socket is closed before restarting with auth."; return; } @@ -369,7 +373,7 @@ void SocketStream::Finish(int result) { } int SocketStream::DidEstablishConnection() { - if (!socket_.get() || !socket_->IsConnected()) { + if (!connection_->socket() || !connection_->socket()->IsConnected()) { next_state_ = STATE_CLOSE; return ERR_CONNECTION_FAILED; } @@ -675,9 +679,11 @@ int SocketStream::DoResolveHost() { DCHECK(context_->host_resolver()); resolver_.reset(new SingleRequestHostResolver(context_->host_resolver())); - return resolver_->Resolve( - resolve_info, &addresses_, base::Bind(&SocketStream::OnIOCompleted, this), - net_log_); + return resolver_->Resolve(resolve_info, + DEFAULT_PRIORITY, + &addresses_, + base::Bind(&SocketStream::OnIOCompleted, this), + net_log_); } int SocketStream::DoResolveHostComplete(int result) { @@ -730,11 +736,12 @@ int SocketStream::DoTcpConnect(int result) { } next_state_ = STATE_TCP_CONNECT_COMPLETE; DCHECK(factory_); - socket_.reset(factory_->CreateTransportClientSocket(addresses_, - net_log_.net_log(), - net_log_.source())); + connection_->SetSocket( + factory_->CreateTransportClientSocket(addresses_, + net_log_.net_log(), + net_log_.source())); metrics_->OnStartConnection(); - return socket_->Connect(io_callback_); + return connection_->socket()->Connect(io_callback_); } int SocketStream::DoTcpConnectComplete(int result) { @@ -819,7 +826,8 @@ int SocketStream::DoWriteTunnelHeaders() { int buf_len = static_cast<int>(tunnel_request_headers_->headers_.size() - tunnel_request_headers_bytes_sent_); DCHECK_GT(buf_len, 0); - return socket_->Write(tunnel_request_headers_.get(), buf_len, io_callback_); + return connection_->socket()->Write( + tunnel_request_headers_.get(), buf_len, io_callback_); } int SocketStream::DoWriteTunnelHeadersComplete(int result) { @@ -862,7 +870,8 @@ int SocketStream::DoReadTunnelHeaders() { tunnel_response_headers_->SetDataOffset(tunnel_response_headers_len_); CHECK(tunnel_response_headers_->data()); - return socket_->Read(tunnel_response_headers_.get(), buf_len, io_callback_); + return connection_->socket()->Read( + tunnel_response_headers_.get(), buf_len, io_callback_); } int SocketStream::DoReadTunnelHeadersComplete(int result) { @@ -953,17 +962,22 @@ int SocketStream::DoSOCKSConnect() { next_state_ = STATE_SOCKS_CONNECT_COMPLETE; - StreamSocket* s = socket_.release(); HostResolver::RequestInfo req_info(HostPortPair::FromURL(url_)); DCHECK(!proxy_info_.is_empty()); - if (proxy_info_.proxy_server().scheme() == ProxyServer::SCHEME_SOCKS5) - s = new SOCKS5ClientSocket(s, req_info); - else - s = new SOCKSClientSocket(s, req_info, context_->host_resolver()); - socket_.reset(s); + scoped_ptr<StreamSocket> s; + if (proxy_info_.proxy_server().scheme() == ProxyServer::SCHEME_SOCKS5) { + s.reset(new SOCKS5ClientSocket(connection_.Pass(), req_info)); + } else { + s.reset(new SOCKSClientSocket(connection_.Pass(), + req_info, + DEFAULT_PRIORITY, + context_->host_resolver())); + } + connection_.reset(new ClientSocketHandle); + connection_->SetSocket(s.Pass()); metrics_->OnCountConnectionType(SocketStreamMetrics::SOCKS_CONNECTION); - return socket_->Connect(io_callback_); + return connection_->socket()->Connect(io_callback_); } int SocketStream::DoSOCKSConnectComplete(int result) { @@ -986,14 +1000,16 @@ int SocketStream::DoSecureProxyConnect() { ssl_context.cert_verifier = context_->cert_verifier(); ssl_context.transport_security_state = context_->transport_security_state(); ssl_context.server_bound_cert_service = context_->server_bound_cert_service(); - socket_.reset(factory_->CreateSSLClientSocket( - socket_.release(), + scoped_ptr<StreamSocket> socket(factory_->CreateSSLClientSocket( + connection_.Pass(), proxy_info_.proxy_server().host_port_pair(), proxy_ssl_config_, ssl_context)); + connection_.reset(new ClientSocketHandle); + connection_->SetSocket(socket.Pass()); next_state_ = STATE_SECURE_PROXY_CONNECT_COMPLETE; metrics_->OnCountConnectionType(SocketStreamMetrics::SECURE_PROXY_CONNECTION); - return socket_->Connect(io_callback_); + return connection_->socket()->Connect(io_callback_); } int SocketStream::DoSecureProxyConnectComplete(int result) { @@ -1025,7 +1041,7 @@ int SocketStream::DoSecureProxyHandleCertError(int result) { int SocketStream::DoSecureProxyHandleCertErrorComplete(int result) { DCHECK_EQ(STATE_NONE, next_state_); if (result == OK) { - if (!socket_->IsConnectedAndIdle()) + if (!connection_->socket()->IsConnectedAndIdle()) return AllowCertErrorForReconnection(&proxy_ssl_config_); next_state_ = STATE_GENERATE_PROXY_AUTH_TOKEN; } else { @@ -1040,13 +1056,16 @@ int SocketStream::DoSSLConnect() { ssl_context.cert_verifier = context_->cert_verifier(); ssl_context.transport_security_state = context_->transport_security_state(); ssl_context.server_bound_cert_service = context_->server_bound_cert_service(); - socket_.reset(factory_->CreateSSLClientSocket(socket_.release(), - HostPortPair::FromURL(url_), - server_ssl_config_, - ssl_context)); + scoped_ptr<StreamSocket> socket( + factory_->CreateSSLClientSocket(connection_.Pass(), + HostPortPair::FromURL(url_), + server_ssl_config_, + ssl_context)); + connection_.reset(new ClientSocketHandle); + connection_->SetSocket(socket.Pass()); next_state_ = STATE_SSL_CONNECT_COMPLETE; metrics_->OnCountConnectionType(SocketStreamMetrics::SSL_CONNECTION); - return socket_->Connect(io_callback_); + return connection_->socket()->Connect(io_callback_); } int SocketStream::DoSSLConnectComplete(int result) { @@ -1082,7 +1101,7 @@ int SocketStream::DoSSLHandleCertErrorComplete(int result) { // we should take care of TLS NPN extension here. if (result == OK) { - if (!socket_->IsConnectedAndIdle()) + if (!connection_->socket()->IsConnectedAndIdle()) return AllowCertErrorForReconnection(&server_ssl_config_); result = DidEstablishConnection(); } else { @@ -1096,7 +1115,7 @@ int SocketStream::DoReadWrite(int result) { next_state_ = STATE_CLOSE; return result; } - if (!socket_.get() || !socket_->IsConnected()) { + if (!connection_->socket() || !connection_->socket()->IsConnected()) { next_state_ = STATE_CLOSE; return ERR_CONNECTION_CLOSED; } @@ -1105,7 +1124,7 @@ int SocketStream::DoReadWrite(int result) { // let's close the socket. // We don't care about receiving data after the socket is closed. if (closing_ && !current_write_buf_.get() && pending_write_bufs_.empty()) { - socket_->Disconnect(); + connection_->socket()->Disconnect(); next_state_ = STATE_CLOSE; return OK; } @@ -1117,7 +1136,7 @@ int SocketStream::DoReadWrite(int result) { if (!read_buf_.get()) { // No read pending and server didn't close the socket. read_buf_ = new IOBuffer(kReadBufferSize); - result = socket_->Read( + result = connection_->socket()->Read( read_buf_.get(), kReadBufferSize, base::Bind(&SocketStream::OnReadCompleted, base::Unretained(this))); @@ -1156,7 +1175,7 @@ int SocketStream::DoReadWrite(int result) { pending_write_bufs_.pop_front(); } - result = socket_->Write( + result = connection_->socket()->Write( current_write_buf_.get(), current_write_buf_->BytesRemaining(), base::Bind(&SocketStream::OnWriteCompleted, base::Unretained(this))); @@ -1188,10 +1207,10 @@ int SocketStream::HandleCertificateRequest(int result, SSLConfig* ssl_config) { return result; } - DCHECK(socket_.get()); + DCHECK(connection_->socket()); scoped_refptr<SSLCertRequestInfo> cert_request_info = new SSLCertRequestInfo; SSLClientSocket* ssl_socket = - static_cast<SSLClientSocket*>(socket_.get()); + static_cast<SSLClientSocket*>(connection_->socket()); ssl_socket->GetSSLCertRequestInfo(cert_request_info.get()); HttpTransactionFactory* factory = context_->http_transaction_factory(); @@ -1237,7 +1256,8 @@ int SocketStream::AllowCertErrorForReconnection(SSLConfig* ssl_config) { // allowed bad certificates in |ssl_config|. // See also net/http/http_network_transaction.cc HandleCertificateError() and // RestartIgnoringLastError(). - SSLClientSocket* ssl_socket = static_cast<SSLClientSocket*>(socket_.get()); + SSLClientSocket* ssl_socket = + static_cast<SSLClientSocket*>(connection_->socket()); SSLInfo ssl_info; ssl_socket->GetSSLInfo(&ssl_info); if (ssl_info.cert.get() == NULL || @@ -1259,8 +1279,8 @@ int SocketStream::AllowCertErrorForReconnection(SSLConfig* ssl_config) { bad_cert.cert_status = ssl_info.cert_status; ssl_config->allowed_bad_certs.push_back(bad_cert); // Restart connection ignoring the bad certificate. - socket_->Disconnect(); - socket_.reset(); + connection_->socket()->Disconnect(); + connection_->SetSocket(scoped_ptr<StreamSocket>()); next_state_ = STATE_TCP_CONNECT; return OK; } @@ -1286,7 +1306,8 @@ void SocketStream::DoRestartWithAuth() { int SocketStream::HandleCertificateError(int result) { DCHECK(IsCertificateError(result)); - SSLClientSocket* ssl_socket = static_cast<SSLClientSocket*>(socket_.get()); + SSLClientSocket* ssl_socket = + static_cast<SSLClientSocket*>(connection_->socket()); DCHECK(ssl_socket); if (!context_) diff --git a/chromium/net/socket_stream/socket_stream.h b/chromium/net/socket_stream/socket_stream.h index 5004060df9a..90aeb8c54b9 100644 --- a/chromium/net/socket_stream/socket_stream.h +++ b/chromium/net/socket_stream/socket_stream.h @@ -28,13 +28,13 @@ namespace net { class AuthChallengeInfo; class CertVerifier; class ClientSocketFactory; +class ClientSocketHandle; class CookieOptions; class HostResolver; class HttpAuthController; class SSLInfo; class ServerBoundCertService; class SingleRequestHostResolver; -class StreamSocket; class SocketStreamMetrics; class TransportSecurityState; class URLRequestContext; @@ -364,7 +364,7 @@ class NET_EXPORT SocketStream scoped_ptr<SingleRequestHostResolver> resolver_; AddressList addresses_; - scoped_ptr<StreamSocket> socket_; + scoped_ptr<ClientSocketHandle> connection_; SSLConfig server_ssl_config_; SSLConfig proxy_ssl_config_; diff --git a/chromium/net/spdy/spdy_credential_builder_unittest.cc b/chromium/net/spdy/spdy_credential_builder_unittest.cc index bc67cc593ae..84aff7a60c4 100644 --- a/chromium/net/spdy/spdy_credential_builder_unittest.cc +++ b/chromium/net/spdy/spdy_credential_builder_unittest.cc @@ -4,7 +4,7 @@ #include "net/spdy/spdy_credential_builder.h" -#include "base/threading/sequenced_worker_pool.h" +#include "base/message_loop/message_loop_proxy.h" #include "crypto/ec_private_key.h" #include "crypto/ec_signature_creator.h" #include "net/cert/asn1_util.h" @@ -23,21 +23,16 @@ const static char kSecretPrefix[] = void CreateCertAndKey(std::string* cert, std::string* key) { // TODO(rch): Share this code with ServerBoundCertServiceTest. - scoped_refptr<base::SequencedWorkerPool> sequenced_worker_pool = - new base::SequencedWorkerPool(1, "CreateCertAndKey"); scoped_ptr<ServerBoundCertService> server_bound_cert_service( new ServerBoundCertService(new DefaultServerBoundCertStore(NULL), - sequenced_worker_pool)); + base::MessageLoopProxy::current())); TestCompletionCallback callback; ServerBoundCertService::RequestHandle request_handle; - int rv = server_bound_cert_service->GetDomainBoundCert( - "www.google.com", key, cert, - callback.callback(), &request_handle); + int rv = server_bound_cert_service->GetOrCreateDomainBoundCert( + "www.google.com", key, cert, callback.callback(), &request_handle); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_EQ(OK, callback.WaitForResult()); - - sequenced_worker_pool->Shutdown(); } } // namespace diff --git a/chromium/net/spdy/spdy_http_stream.cc b/chromium/net/spdy/spdy_http_stream.cc index 08e8b58bf61..4d9117514ab 100644 --- a/chromium/net/spdy/spdy_http_stream.cc +++ b/chromium/net/spdy/spdy_http_stream.cc @@ -524,4 +524,9 @@ void SpdyHttpStream::Drain(HttpNetworkSession* session) { delete this; } +void SpdyHttpStream::SetPriority(RequestPriority priority) { + // TODO(akalin): Plumb this through to |stream_request_| and + // |stream_|. +} + } // namespace net diff --git a/chromium/net/spdy/spdy_http_stream.h b/chromium/net/spdy/spdy_http_stream.h index 65d98784181..f6e39cd102e 100644 --- a/chromium/net/spdy/spdy_http_stream.h +++ b/chromium/net/spdy/spdy_http_stream.h @@ -73,6 +73,7 @@ class NET_EXPORT_PRIVATE SpdyHttpStream : public SpdyStream::Delegate, SSLCertRequestInfo* cert_request_info) OVERRIDE; virtual bool IsSpdyHttpStream() const OVERRIDE; virtual void Drain(HttpNetworkSession* session) OVERRIDE; + virtual void SetPriority(RequestPriority priority) OVERRIDE; // SpdyStream::Delegate implementation. virtual void OnRequestHeadersSent() OVERRIDE; diff --git a/chromium/net/spdy/spdy_http_stream_unittest.cc b/chromium/net/spdy/spdy_http_stream_unittest.cc index 55387cf410a..7388dc01ddf 100644 --- a/chromium/net/spdy/spdy_http_stream_unittest.cc +++ b/chromium/net/spdy/spdy_http_stream_unittest.cc @@ -7,8 +7,8 @@ #include <vector> #include "base/memory/scoped_ptr.h" +#include "base/message_loop/message_loop_proxy.h" #include "base/stl_util.h" -#include "base/threading/sequenced_worker_pool.h" #include "crypto/ec_private_key.h" #include "crypto/ec_signature_creator.h" #include "crypto/signature_creator.h" @@ -542,7 +542,7 @@ void GetECServerBoundCertAndProof( TestCompletionCallback callback; std::string key; ServerBoundCertService::RequestHandle request_handle; - int rv = server_bound_cert_service->GetDomainBoundCert( + int rv = server_bound_cert_service->GetOrCreateDomainBoundCert( host, &key, cert, callback.callback(), &request_handle); EXPECT_EQ(ERR_IO_PENDING, rv); @@ -783,11 +783,9 @@ TEST_P(SpdyHttpStreamTest, SendCredentialsEC) { if (GetParam() < kProtoSPDY3) return; - scoped_refptr<base::SequencedWorkerPool> sequenced_worker_pool = - new base::SequencedWorkerPool(1, "SpdyHttpStreamSpdy3Test"); scoped_ptr<ServerBoundCertService> server_bound_cert_service( new ServerBoundCertService(new DefaultServerBoundCertStore(NULL), - sequenced_worker_pool)); + base::MessageLoopProxy::current())); std::string cert; std::string proof; GetECServerBoundCertAndProof("www.gmail.com", @@ -795,19 +793,15 @@ TEST_P(SpdyHttpStreamTest, SendCredentialsEC) { &cert, &proof); TestSendCredentials(server_bound_cert_service.get(), cert, proof); - - sequenced_worker_pool->Shutdown(); } TEST_P(SpdyHttpStreamTest, DontSendCredentialsForHttpUrlsEC) { if (GetParam() < kProtoSPDY3) return; - scoped_refptr<base::SequencedWorkerPool> sequenced_worker_pool = - new base::SequencedWorkerPool(1, "SpdyHttpStreamSpdy3Test"); scoped_ptr<ServerBoundCertService> server_bound_cert_service( new ServerBoundCertService(new DefaultServerBoundCertStore(NULL), - sequenced_worker_pool)); + base::MessageLoopProxy::current())); std::string cert; std::string proof; GetECServerBoundCertAndProof("proxy.google.com", @@ -887,7 +881,6 @@ TEST_P(SpdyHttpStreamTest, DontSendCredentialsForHttpUrlsEC) { ASSERT_TRUE(response.headers.get() != NULL); ASSERT_EQ(200, response.headers->response_code()); deterministic_data_->RunFor(1); - sequenced_worker_pool->Shutdown(); } #endif // !defined(USE_OPENSSL) diff --git a/chromium/net/spdy/spdy_network_transaction_unittest.cc b/chromium/net/spdy/spdy_network_transaction_unittest.cc index 98c4d134b4a..f3f9f344995 100644 --- a/chromium/net/spdy/spdy_network_transaction_unittest.cc +++ b/chromium/net/spdy/spdy_network_transaction_unittest.cc @@ -10,6 +10,7 @@ #include "base/file_util.h" #include "base/files/scoped_temp_dir.h" #include "base/memory/scoped_vector.h" +#include "base/run_loop.h" #include "base/stl_util.h" #include "net/base/auth.h" #include "net/base/net_log_unittest.h" @@ -80,6 +81,13 @@ class SpdyNetworkTransactionTest SpdyNetworkTransactionTest() : spdy_util_(GetParam().protocol) { } + virtual ~SpdyNetworkTransactionTest() { + // UploadDataStream posts deletion tasks back to the message loop on + // destruction. + upload_data_stream_.reset(); + base::RunLoop().RunUntilIdle(); + } + virtual void SetUp() { google_get_request_initialized_ = false; google_post_request_initialized_ = false; @@ -87,11 +95,6 @@ class SpdyNetworkTransactionTest ASSERT_TRUE(temp_dir_.CreateUniqueTempDir()); } - virtual void TearDown() { - // Empty the current queue. - base::MessageLoop::current()->RunUntilIdle(); - } - struct TransactionHelperResult { int rv; std::string status_line; @@ -519,7 +522,7 @@ class SpdyNetworkTransactionTest // reads until we complete our callback. while (!callback.have_result()) { data->CompleteRead(); - base::MessageLoop::current()->RunUntilIdle(); + base::RunLoop().RunUntilIdle(); } rv = callback.WaitForResult(); } else if (rv <= 0) { @@ -573,7 +576,7 @@ class SpdyNetworkTransactionTest rv = trans2->Start( &CreateGetPushRequest(), callback.callback(), BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); - base::MessageLoop::current()->RunUntilIdle(); + base::RunLoop().RunUntilIdle(); // The data for the pushed path may be coming in more than 1 frame. Compile // the results into a single string. @@ -1880,10 +1883,10 @@ TEST_P(SpdyNetworkTransactionTest, DelayedChunkedPost) { helper.AddData(&data); ASSERT_TRUE(helper.StartDefaultTest()); - base::MessageLoop::current()->RunUntilIdle(); + base::RunLoop().RunUntilIdle(); helper.request().upload_data_stream->AppendChunk( kUploadData, kUploadDataSize, false); - base::MessageLoop::current()->RunUntilIdle(); + base::RunLoop().RunUntilIdle(); helper.request().upload_data_stream->AppendChunk( kUploadData, kUploadDataSize, true); @@ -2252,7 +2255,7 @@ TEST_P(SpdyNetworkTransactionTest, CancelledTransaction) { // Flush the MessageLoop while the SpdySessionDependencies (in particular, the // MockClientSocketFactory) are still alive. - base::MessageLoop::current()->RunUntilIdle(); + base::RunLoop().RunUntilIdle(); helper.VerifyDataNotConsumed(); } @@ -2408,7 +2411,7 @@ TEST_P(SpdyNetworkTransactionTest, DeleteSessionOnReadCallback) { data.CompleteRead(); // Finish running rest of tasks. - base::MessageLoop::current()->RunUntilIdle(); + base::RunLoop().RunUntilIdle(); helper.VerifyDataConsumed(); } @@ -2469,12 +2472,12 @@ TEST_P(SpdyNetworkTransactionTest, RedirectGetRequest) { d.set_quit_on_redirect(true); r.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(1, d.received_redirect_count()); r.FollowDeferredRedirect(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(1, d.response_started_count()); EXPECT_FALSE(d.received_data_before_response()); EXPECT_EQ(net::URLRequestStatus::SUCCESS, r.status().status()); @@ -2559,7 +2562,7 @@ TEST_P(SpdyNetworkTransactionTest, RedirectServerPush) { AddSocketDataProvider(&data); r.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(0, d.received_redirect_count()); std::string contents("hello!"); @@ -2572,11 +2575,11 @@ TEST_P(SpdyNetworkTransactionTest, RedirectServerPush) { d2.set_quit_on_redirect(true); r2.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(1, d2.received_redirect_count()); r2.FollowDeferredRedirect(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(1, d2.response_started_count()); EXPECT_FALSE(d2.received_data_before_response()); EXPECT_EQ(net::URLRequestStatus::SUCCESS, r2.status().status()); @@ -3786,7 +3789,7 @@ TEST_P(SpdyNetworkTransactionTest, BufferFull) { // Flush the MessageLoop while the SpdySessionDependencies (in particular, the // MockClientSocketFactory) are still alive. - base::MessageLoop::current()->RunUntilIdle(); + base::RunLoop().RunUntilIdle(); // Verify that we consumed all test data. helper.VerifyDataConsumed(); @@ -3883,7 +3886,7 @@ TEST_P(SpdyNetworkTransactionTest, Buffering) { // Flush the MessageLoop while the SpdySessionDependencies (in particular, the // MockClientSocketFactory) are still alive. - base::MessageLoop::current()->RunUntilIdle(); + base::RunLoop().RunUntilIdle(); // Verify that we consumed all test data. helper.VerifyDataConsumed(); @@ -3977,7 +3980,7 @@ TEST_P(SpdyNetworkTransactionTest, BufferedAll) { // Flush the MessageLoop while the SpdySessionDependencies (in particular, the // MockClientSocketFactory) are still alive. - base::MessageLoop::current()->RunUntilIdle(); + base::RunLoop().RunUntilIdle(); // Verify that we consumed all test data. helper.VerifyDataConsumed(); @@ -4072,7 +4075,7 @@ TEST_P(SpdyNetworkTransactionTest, BufferedClosed) { // Flush the MessageLoop while the SpdySessionDependencies (in particular, the // MockClientSocketFactory) are still alive. - base::MessageLoop::current()->RunUntilIdle(); + base::RunLoop().RunUntilIdle(); // Verify that we consumed all test data. helper.VerifyDataConsumed(); @@ -4143,7 +4146,7 @@ TEST_P(SpdyNetworkTransactionTest, BufferedCancelled) { // Flush the MessageLoop; this will cause the buffered IO task // to run for the final time. - base::MessageLoop::current()->RunUntilIdle(); + base::RunLoop().RunUntilIdle(); // Verify that we consumed all test data. helper.VerifyDataConsumed(); @@ -4787,7 +4790,7 @@ TEST_P(SpdyNetworkTransactionTest, VerifyRetryOnConnectionReset) { if (variant == VARIANT_RST_DURING_READ_COMPLETION) { // Writes to the socket complete asynchronously on SPDY by running // through the message loop. Complete the write here. - base::MessageLoop::current()->RunUntilIdle(); + base::RunLoop().RunUntilIdle(); } // Now schedule the ERR_CONNECTION_RESET. @@ -5102,7 +5105,7 @@ TEST_P(SpdyNetworkTransactionTest, ServerPushClaimBeforeHeaders) { &CreateGetPushRequest(), callback.callback(), BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); data.RunFor(3); - base::MessageLoop::current()->RunUntilIdle(); + base::RunLoop().RunUntilIdle(); // Read the server push body. std::string result2; @@ -5238,7 +5241,7 @@ TEST_P(SpdyNetworkTransactionTest, ServerPushWithTwoHeaderFrames) { &CreateGetPushRequest(), callback.callback(), BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); data.RunFor(3); - base::MessageLoop::current()->RunUntilIdle(); + base::RunLoop().RunUntilIdle(); // Read the server push body. std::string result2; @@ -5371,7 +5374,7 @@ TEST_P(SpdyNetworkTransactionTest, ServerPushWithNoStatusHeaderFrames) { &CreateGetPushRequest(), callback.callback(), BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); data.RunFor(2); - base::MessageLoop::current()->RunUntilIdle(); + base::RunLoop().RunUntilIdle(); // Read the server push body. std::string result2; @@ -5726,7 +5729,7 @@ TEST_P(SpdyNetworkTransactionTest, OutOfOrderSynStream) { // Run the message loop, but do not allow the write to complete. // This leaves the SpdySession with a write pending, which prevents // SpdySession from attempting subsequent writes until this write completes. - base::MessageLoop::current()->RunUntilIdle(); + base::RunLoop().RunUntilIdle(); // Now, start both new transactions HttpRequestInfo info2 = CreateGetRequest(); @@ -5735,7 +5738,7 @@ TEST_P(SpdyNetworkTransactionTest, OutOfOrderSynStream) { new HttpNetworkTransaction(MEDIUM, helper.session().get())); rv = trans2->Start(&info2, callback2.callback(), BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); - base::MessageLoop::current()->RunUntilIdle(); + base::RunLoop().RunUntilIdle(); HttpRequestInfo info3 = CreateGetRequest(); TestCompletionCallback callback3; @@ -5743,7 +5746,7 @@ TEST_P(SpdyNetworkTransactionTest, OutOfOrderSynStream) { new HttpNetworkTransaction(HIGHEST, helper.session().get())); rv = trans3->Start(&info3, callback3.callback(), BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); - base::MessageLoop::current()->RunUntilIdle(); + base::RunLoop().RunUntilIdle(); // We now have two SYN_STREAM frames queued up which will be // dequeued only once the first write completes, which we @@ -5955,7 +5958,7 @@ TEST_P(SpdyNetworkTransactionTest, WindowUpdateSent) { // Force write of WINDOW_UPDATE which was scheduled during the above // read. - base::MessageLoop::current()->RunUntilIdle(); + base::RunLoop().RunUntilIdle(); // Read EOF. data.CompleteRead(); @@ -6139,7 +6142,7 @@ TEST_P(SpdyNetworkTransactionTest, FlowControlStallResume) { int rv = trans->Start(&helper.request(), callback.callback(), BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); - base::MessageLoop::current()->RunUntilIdle(); // Write as much as we can. + base::RunLoop().RunUntilIdle(); // Write as much as we can. SpdyHttpStream* stream = static_cast<SpdyHttpStream*>(trans->stream_.get()); ASSERT_TRUE(stream != NULL); diff --git a/chromium/net/spdy/spdy_session.cc b/chromium/net/spdy/spdy_session.cc index baef1952652..7f31cd5cbc6 100644 --- a/chromium/net/spdy/spdy_session.cc +++ b/chromium/net/spdy/spdy_session.cc @@ -205,13 +205,38 @@ base::Value* NetLogSpdyGoAwayCallback(SpdyStreamId last_stream_id, return dict; } +// Helper function to return the total size of an array of objects +// with .size() member functions. +template <typename T, size_t N> size_t GetTotalSize(const T (&arr)[N]) { + size_t total_size = 0; + for (size_t i = 0; i < N; ++i) { + total_size += arr[i].size(); + } + return total_size; +} + +// Helper class for std:find_if on STL container containing +// SpdyStreamRequest weak pointers. +class RequestEquals { + public: + RequestEquals(const base::WeakPtr<SpdyStreamRequest>& request) + : request_(request) {} + + bool operator()(const base::WeakPtr<SpdyStreamRequest>& request) const { + return request_.get() == request.get(); + } + + private: + const base::WeakPtr<SpdyStreamRequest> request_; +}; + // The maximum number of concurrent streams we will ever create. Even if // the server permits more, we will never exceed this limit. const size_t kMaxConcurrentStreamLimit = 256; } // namespace -SpdyStreamRequest::SpdyStreamRequest() { +SpdyStreamRequest::SpdyStreamRequest() : weak_ptr_factory_(this) { Reset(); } @@ -226,9 +251,9 @@ int SpdyStreamRequest::StartRequest( RequestPriority priority, const BoundNetLog& net_log, const CompletionCallback& callback) { - DCHECK(session.get()); - DCHECK(!session_.get()); - DCHECK(!stream_.get()); + DCHECK(session); + DCHECK(!session_); + DCHECK(!stream_); DCHECK(callback_.is_null()); type_ = type; @@ -239,7 +264,7 @@ int SpdyStreamRequest::StartRequest( callback_ = callback; base::WeakPtr<SpdyStream> stream; - int rv = session->TryCreateStream(this, &stream); + int rv = session->TryCreateStream(weak_ptr_factory_.GetWeakPtr(), &stream); if (rv == OK) { Reset(); stream_ = stream; @@ -248,34 +273,36 @@ int SpdyStreamRequest::StartRequest( } void SpdyStreamRequest::CancelRequest() { - if (session_.get()) - session_->CancelStreamRequest(this); + if (session_) + session_->CancelStreamRequest(weak_ptr_factory_.GetWeakPtr()); Reset(); + // Do this to cancel any pending CompleteStreamRequest() tasks. + weak_ptr_factory_.InvalidateWeakPtrs(); } base::WeakPtr<SpdyStream> SpdyStreamRequest::ReleaseStream() { - DCHECK(!session_.get()); + DCHECK(!session_); base::WeakPtr<SpdyStream> stream = stream_; - DCHECK(stream.get()); + DCHECK(stream); Reset(); return stream; } void SpdyStreamRequest::OnRequestCompleteSuccess( - base::WeakPtr<SpdyStream>* stream) { - DCHECK(session_.get()); - DCHECK(!stream_.get()); + const base::WeakPtr<SpdyStream>& stream) { + DCHECK(session_); + DCHECK(!stream_); DCHECK(!callback_.is_null()); CompletionCallback callback = callback_; Reset(); - DCHECK(*stream); - stream_ = *stream; + DCHECK(stream); + stream_ = stream; callback.Run(OK); } void SpdyStreamRequest::OnRequestCompleteFailure(int rv) { - DCHECK(session_.get()); - DCHECK(!stream_.get()); + DCHECK(session_); + DCHECK(!stream_); DCHECK(!callback_.is_null()); CompletionCallback callback = callback_; Reset(); @@ -490,7 +517,7 @@ Error SpdySession::InitializeWithSocket( error = OK; if (error == OK) { DCHECK_NE(availability_state_, STATE_CLOSED); - connection_->AddLayeredPool(this); + connection_->AddHigherLayeredPool(this); if (enable_sending_initial_data_) SendInitialData(); pool_ = pool; @@ -565,9 +592,10 @@ Error SpdySession::TryAccessStream(const GURL& url) { return OK; } -int SpdySession::TryCreateStream(SpdyStreamRequest* request, - base::WeakPtr<SpdyStream>* stream) { - CHECK(request); +int SpdySession::TryCreateStream( + const base::WeakPtr<SpdyStreamRequest>& request, + base::WeakPtr<SpdyStream>* stream) { + DCHECK(request); if (availability_state_ == STATE_GOING_AWAY) return ERR_FAILED; @@ -641,8 +669,9 @@ int SpdySession::CreateStream(const SpdyStreamRequest& request, return OK; } -void SpdySession::CancelStreamRequest(SpdyStreamRequest* request) { - CHECK(request); +void SpdySession::CancelStreamRequest( + const base::WeakPtr<SpdyStreamRequest>& request) { + DCHECK(request); if (DCHECK_IS_ON()) { // |request| should not be in a queue not matching its priority. @@ -650,7 +679,9 @@ void SpdySession::CancelStreamRequest(SpdyStreamRequest* request) { if (request->priority() == i) continue; PendingStreamRequestQueue* queue = &pending_create_stream_queues_[i]; - DCHECK(std::find(queue->begin(), queue->end(), request) == queue->end()); + DCHECK(std::find_if(queue->begin(), + queue->end(), + RequestEquals(request)) == queue->end()); } } @@ -659,18 +690,30 @@ void SpdySession::CancelStreamRequest(SpdyStreamRequest* request) { // Remove |request| from |queue| while preserving the order of the // other elements. PendingStreamRequestQueue::iterator it = - std::find(queue->begin(), queue->end(), request); + std::find_if(queue->begin(), queue->end(), RequestEquals(request)); + // The request may already be removed if there's a + // CompleteStreamRequest() in flight. if (it != queue->end()) { it = queue->erase(it); // |request| should be in the queue at most once, and if it is // present, should not be pending completion. - DCHECK(std::find(it, queue->end(), request) == queue->end()); - DCHECK(!ContainsKey(pending_stream_request_completions_, - request)); - return; + DCHECK(std::find_if(it, queue->end(), RequestEquals(request)) == + queue->end()); } +} + +base::WeakPtr<SpdyStreamRequest> SpdySession::GetNextPendingStreamRequest() { + for (int j = NUM_PRIORITIES - 1; j >= MINIMUM_PRIORITY; --j) { + if (pending_create_stream_queues_[j].empty()) + continue; - pending_stream_request_completions_.erase(request); + base::WeakPtr<SpdyStreamRequest> pending_request = + pending_create_stream_queues_[j].front(); + DCHECK(pending_request); + pending_create_stream_queues_[j].pop_front(); + return pending_request; + } + return base::WeakPtr<SpdyStreamRequest>(); } void SpdySession::ProcessPendingStreamRequests() { @@ -684,27 +727,16 @@ void SpdySession::ProcessPendingStreamRequests() { } for (size_t i = 0; max_requests_to_process == 0 || i < max_requests_to_process; ++i) { - bool processed_request = false; - for (int j = NUM_PRIORITIES - 1; j >= MINIMUM_PRIORITY; --j) { - if (pending_create_stream_queues_[j].empty()) - continue; - - SpdyStreamRequest* pending_request = - pending_create_stream_queues_[j].front(); - CHECK(pending_request); - pending_create_stream_queues_[j].pop_front(); - processed_request = true; - DCHECK(!ContainsKey(pending_stream_request_completions_, - pending_request)); - pending_stream_request_completions_.insert(pending_request); - base::MessageLoop::current()->PostTask( - FROM_HERE, - base::Bind(&SpdySession::CompleteStreamRequest, - weak_factory_.GetWeakPtr(), pending_request)); - break; - } - if (!processed_request) + base::WeakPtr<SpdyStreamRequest> pending_request = + GetNextPendingStreamRequest(); + if (!pending_request) break; + + base::MessageLoop::current()->PostTask( + FROM_HERE, + base::Bind(&SpdySession::CompleteStreamRequest, + weak_factory_.GetWeakPtr(), + pending_request)); } } @@ -1016,9 +1048,25 @@ void SpdySession::CloseActiveStreamIterator(ActiveStreamMap::iterator it, // push is hardly used. Write tests for this and fix this. (See // http://crbug.com/261712 .) if (owned_stream->type() == SPDY_PUSH_STREAM) - unclaimed_pushed_streams_.erase(owned_stream->url()); + unclaimed_pushed_streams_.erase(owned_stream->url()); + + base::WeakPtr<SpdySession> weak_this = GetWeakPtr(); DeleteStream(owned_stream.Pass(), status); + + if (!weak_this) + return; + + if (availability_state_ == STATE_CLOSED) + return; + + // If there are no active streams and the socket pool is stalled, close the + // session to free up a socket slot. + if (active_streams_.empty() && connection_->IsPoolStalled()) { + CloseSessionResult result = + DoCloseSession(ERR_CONNECTION_CLOSED, "Closing idle connection."); + DCHECK_NE(result, SESSION_ALREADY_CLOSED); + } } void SpdySession::CloseCreatedStreamIterator(CreatedStreamSet::iterator it, @@ -1377,7 +1425,6 @@ void SpdySession::DcheckGoingAway() const { DCHECK(pending_create_stream_queues_[i].empty()); } } - DCHECK(pending_stream_request_completions_.empty()); DCHECK(created_streams_.empty()); } @@ -1395,41 +1442,40 @@ void SpdySession::StartGoingAway(SpdyStreamId last_good_stream_id, DCHECK_GE(availability_state_, STATE_GOING_AWAY); // The loops below are carefully written to avoid reentrancy problems. - // - // TODO(akalin): Any of the functions below can cause |this| to be - // deleted, so handle that below (and add tests for it). - - for (int i = 0; i < NUM_PRIORITIES; ++i) { - PendingStreamRequestQueue queue; - queue.swap(pending_create_stream_queues_[i]); - for (PendingStreamRequestQueue::const_iterator it = queue.begin(); - it != queue.end(); ++it) { - CHECK(*it); - (*it)->OnRequestCompleteFailure(ERR_ABORTED); - } - } - PendingStreamRequestCompletionSet pending_completions; - pending_completions.swap(pending_stream_request_completions_); - for (PendingStreamRequestCompletionSet::const_iterator it = - pending_completions.begin(); - it != pending_completions.end(); ++it) { - (*it)->OnRequestCompleteFailure(ERR_ABORTED); + while (true) { + size_t old_size = GetTotalSize(pending_create_stream_queues_); + base::WeakPtr<SpdyStreamRequest> pending_request = + GetNextPendingStreamRequest(); + if (!pending_request) + break; + // No new stream requests should be added while the session is + // going away. + DCHECK_GT(old_size, GetTotalSize(pending_create_stream_queues_)); + pending_request->OnRequestCompleteFailure(ERR_ABORTED); } while (true) { + size_t old_size = active_streams_.size(); ActiveStreamMap::iterator it = active_streams_.lower_bound(last_good_stream_id + 1); if (it == active_streams_.end()) break; LogAbandonedActiveStream(it, status); CloseActiveStreamIterator(it, status); + // No new streams should be activated while the session is going + // away. + DCHECK_GT(old_size, active_streams_.size()); } while (!created_streams_.empty()) { + size_t old_size = created_streams_.size(); CreatedStreamSet::iterator it = created_streams_.begin(); LogAbandonedStream(*it, status); CloseCreatedStreamIterator(it, status); + // No new streams should be created while the session is going + // away. + DCHECK_GT(old_size, created_streams_.size()); } write_queue_.RemovePendingWritesForStreamsAfter(last_good_stream_id); @@ -2677,25 +2723,20 @@ void SpdySession::RecordHistograms() { } } -void SpdySession::CompleteStreamRequest(SpdyStreamRequest* pending_request) { - CHECK(pending_request); - - PendingStreamRequestCompletionSet::iterator it = - pending_stream_request_completions_.find(pending_request); - +void SpdySession::CompleteStreamRequest( + const base::WeakPtr<SpdyStreamRequest>& pending_request) { // Abort if the request has already been cancelled. - if (it == pending_stream_request_completions_.end()) + if (!pending_request) return; base::WeakPtr<SpdyStream> stream; int rv = CreateStream(*pending_request, &stream); - pending_stream_request_completions_.erase(it); if (rv == OK) { - DCHECK(stream.get()); - pending_request->OnRequestCompleteSuccess(&stream); + DCHECK(stream); + pending_request->OnRequestCompleteSuccess(stream); } else { - DCHECK(!stream.get()); + DCHECK(!stream); pending_request->OnRequestCompleteFailure(rv); } } @@ -2859,20 +2900,6 @@ void SpdySession::QueueSendStalledStream(const SpdyStream& stream) { stream_send_unstall_queue_[stream.priority()].push_back(stream.stream_id()); } -namespace { - -// Helper function to return the total size of an array of objects -// with .size() member functions. -template <typename T, size_t N> size_t GetTotalSize(const T (&arr)[N]) { - size_t total_size = 0; - for (size_t i = 0; i < N; ++i) { - total_size += arr[i].size(); - } - return total_size; -} - -} // namespace - void SpdySession::ResumeSendStalledStreams() { DCHECK_EQ(flow_control_state_, FLOW_CONTROL_STREAM_AND_SESSION); diff --git a/chromium/net/spdy/spdy_session.h b/chromium/net/spdy/spdy_session.h index 819db111b26..f9c0623d5d4 100644 --- a/chromium/net/spdy/spdy_session.h +++ b/chromium/net/spdy/spdy_session.h @@ -157,7 +157,7 @@ class NET_EXPORT_PRIVATE SpdyStreamRequest { // Called by |session_| when the stream attempt has finished // successfully. - void OnRequestCompleteSuccess(base::WeakPtr<SpdyStream>* stream); + void OnRequestCompleteSuccess(const base::WeakPtr<SpdyStream>& stream); // Called by |session_| when the stream attempt has finished with an // error. Also called with ERR_ABORTED if |session_| is destroyed @@ -172,6 +172,7 @@ class NET_EXPORT_PRIVATE SpdyStreamRequest { void Reset(); + base::WeakPtrFactory<SpdyStreamRequest> weak_ptr_factory_; SpdyStreamType type_; base::WeakPtr<SpdySession> session_; base::WeakPtr<SpdyStream> stream_; @@ -185,7 +186,7 @@ class NET_EXPORT_PRIVATE SpdyStreamRequest { class NET_EXPORT SpdySession : public BufferedSpdyFramerVisitorInterface, public SpdyFramerDebugVisitorInterface, - public LayeredPool { + public HigherLayeredPool { public: // TODO(akalin): Use base::TickClock when it becomes available. typedef base::TimeTicks (*TimeFunc)(void); @@ -479,7 +480,7 @@ class NET_EXPORT SpdySession : public BufferedSpdyFramerVisitorInterface, // Must be used only by |pool_|. base::WeakPtr<SpdySession> GetWeakPtr(); - // LayeredPool implementation: + // HigherLayeredPool implementation: virtual bool CloseOneIdleConnection() OVERRIDE; private: @@ -501,8 +502,8 @@ class NET_EXPORT SpdySession : public BufferedSpdyFramerVisitorInterface, FRIEND_TEST_ALL_PREFIXES(SpdySessionTest, SessionFlowControlNoSendLeaks); FRIEND_TEST_ALL_PREFIXES(SpdySessionTest, SessionFlowControlEndToEnd); - typedef std::deque<SpdyStreamRequest*> PendingStreamRequestQueue; - typedef std::set<SpdyStreamRequest*> PendingStreamRequestCompletionSet; + typedef std::deque<base::WeakPtr<SpdyStreamRequest> > + PendingStreamRequestQueue; struct ActiveStreamInfo { ActiveStreamInfo(); @@ -574,7 +575,7 @@ class NET_EXPORT SpdySession : public BufferedSpdyFramerVisitorInterface, // |request->OnRequestComplete{Success,Failure}()| will be called // when the stream is created (unless it is cancelled). Otherwise, // no stream is created and the error is returned. - int TryCreateStream(SpdyStreamRequest* request, + int TryCreateStream(const base::WeakPtr<SpdyStreamRequest>& request, base::WeakPtr<SpdyStream>* stream); // Actually create a stream into |stream|. Returns OK if successful; @@ -584,7 +585,11 @@ class NET_EXPORT SpdySession : public BufferedSpdyFramerVisitorInterface, // Called by SpdyStreamRequest to remove |request| from the stream // creation queue. - void CancelStreamRequest(SpdyStreamRequest* request); + void CancelStreamRequest(const base::WeakPtr<SpdyStreamRequest>& request); + + // Returns the next pending stream request to process, or NULL if + // there is none. + base::WeakPtr<SpdyStreamRequest> GetNextPendingStreamRequest(); // Called when there is room to create more streams (e.g., a stream // was closed). Processes as many pending stream requests as @@ -783,7 +788,8 @@ class NET_EXPORT SpdySession : public BufferedSpdyFramerVisitorInterface, // Invokes a user callback for stream creation. We provide this method so it // can be deferred to the MessageLoop, so we avoid re-entrancy problems. - void CompleteStreamRequest(SpdyStreamRequest* pending_request); + void CompleteStreamRequest( + const base::WeakPtr<SpdyStreamRequest>& pending_request); // Remove old unclaimed pushed streams. void DeleteExpiredPushedStreams(); @@ -960,12 +966,6 @@ class NET_EXPORT SpdySession : public BufferedSpdyFramerVisitorInterface, // not yet been satisfied. PendingStreamRequestQueue pending_create_stream_queues_[NUM_PRIORITIES]; - // A set of requests that are waiting to be completed (i.e., for the - // stream to actually be created). This is necessary since we kick - // off the stream creation asynchronously, and so the request may be - // cancelled before the asynchronous task to create the stream runs. - PendingStreamRequestCompletionSet pending_stream_request_completions_; - // Map from stream id to all active streams. Streams are active in the sense // that they have a consumer (typically SpdyNetworkTransaction and regardless // of whether or not there is currently any ongoing IO [might be waiting for diff --git a/chromium/net/spdy/spdy_session_pool_unittest.cc b/chromium/net/spdy/spdy_session_pool_unittest.cc index 9d0679d4d4b..33688759096 100644 --- a/chromium/net/spdy/spdy_session_pool_unittest.cc +++ b/chromium/net/spdy/spdy_session_pool_unittest.cc @@ -337,9 +337,12 @@ void SpdySessionPoolTest::RunIPPoolingTest( // This test requires that the HostResolver cache be populated. Normal // code would have done this already, but we do it manually. HostResolver::RequestInfo info(HostPortPair(test_hosts[i].name, kTestPort)); - session_deps_.host_resolver->Resolve( - info, &test_hosts[i].addresses, CompletionCallback(), NULL, - BoundNetLog()); + session_deps_.host_resolver->Resolve(info, + DEFAULT_PRIORITY, + &test_hosts[i].addresses, + CompletionCallback(), + NULL, + BoundNetLog()); // Setup a SpdySessionKey test_hosts[i].key = SpdySessionKey( diff --git a/chromium/net/spdy/spdy_session_unittest.cc b/chromium/net/spdy/spdy_session_unittest.cc index f0d448cd700..72dc5747998 100644 --- a/chromium/net/spdy/spdy_session_unittest.cc +++ b/chromium/net/spdy/spdy_session_unittest.cc @@ -193,6 +193,96 @@ TEST_P(SpdySessionTest, InitialReadError) { spdy_session_pool_, key_, ERR_FAILED); } +namespace { + +// A helper class that vends a callback that, when fired, destroys a +// given SpdyStreamRequest. +class StreamRequestDestroyingCallback : public TestCompletionCallbackBase { + public: + StreamRequestDestroyingCallback() {} + + virtual ~StreamRequestDestroyingCallback() {} + + void SetRequestToDestroy(scoped_ptr<SpdyStreamRequest> request) { + request_ = request.Pass(); + } + + CompletionCallback MakeCallback() { + return base::Bind(&StreamRequestDestroyingCallback::OnComplete, + base::Unretained(this)); + } + + private: + void OnComplete(int result) { + request_.reset(); + SetResult(result); + } + + scoped_ptr<SpdyStreamRequest> request_; +}; + +} // namespace + +// Request kInitialMaxConcurrentStreams streams. Request two more +// streams, but have the callback for one destroy the second stream +// request. Close the session. Nothing should blow up. This is a +// regression test for http://crbug.com/250841 . +TEST_P(SpdySessionTest, PendingStreamCancellingAnother) { + session_deps_.host_resolver->set_synchronous_mode(true); + + MockRead reads[] = {MockRead(ASYNC, 0, 0), }; + + DeterministicSocketData data(reads, arraysize(reads), NULL, 0); + MockConnect connect_data(SYNCHRONOUS, OK); + data.set_connect_data(connect_data); + session_deps_.deterministic_socket_factory->AddSocketDataProvider(&data); + + SSLSocketDataProvider ssl(SYNCHRONOUS, OK); + session_deps_.deterministic_socket_factory->AddSSLSocketDataProvider(&ssl); + + CreateDeterministicNetworkSession(); + + base::WeakPtr<SpdySession> session = + CreateInsecureSpdySession(http_session_, key_, BoundNetLog()); + + // Create the maximum number of concurrent streams. + for (size_t i = 0; i < kInitialMaxConcurrentStreams; ++i) { + base::WeakPtr<SpdyStream> spdy_stream = CreateStreamSynchronously( + SPDY_BIDIRECTIONAL_STREAM, session, test_url_, MEDIUM, BoundNetLog()); + ASSERT_TRUE(spdy_stream != NULL); + } + + SpdyStreamRequest request1; + scoped_ptr<SpdyStreamRequest> request2(new SpdyStreamRequest); + + StreamRequestDestroyingCallback callback1; + ASSERT_EQ(ERR_IO_PENDING, + request1.StartRequest(SPDY_BIDIRECTIONAL_STREAM, + session, + test_url_, + MEDIUM, + BoundNetLog(), + callback1.MakeCallback())); + + // |callback2| is never called. + TestCompletionCallback callback2; + ASSERT_EQ(ERR_IO_PENDING, + request2->StartRequest(SPDY_BIDIRECTIONAL_STREAM, + session, + test_url_, + MEDIUM, + BoundNetLog(), + callback2.callback())); + + callback1.SetRequestToDestroy(request2.Pass()); + + session->CloseSessionOnError(ERR_ABORTED, "Aborting session"); + + EXPECT_EQ(ERR_ABORTED, callback1.WaitForResult()); + + data.RunFor(1); +} + // A session receiving a GOAWAY frame with no active streams should // immediately close. TEST_P(SpdySessionTest, GoAwayWithNoActiveStreams) { @@ -2669,7 +2759,7 @@ TEST_P(SpdySessionTest, CloseOneIdleConnection) { TestCompletionCallback callback2; HostPortPair host_port2("2.com", 80); scoped_refptr<TransportSocketParams> params2( - new TransportSocketParams(host_port2, DEFAULT_PRIORITY, false, false, + new TransportSocketParams(host_port2, false, false, OnHostResolutionCallback())); scoped_ptr<ClientSocketHandle> connection2(new ClientSocketHandle); EXPECT_EQ(ERR_IO_PENDING, @@ -2731,8 +2821,12 @@ TEST_P(SpdySessionTest, CloseOneIdleConnectionWithAlias) { AddressList addresses; // Pre-populate the DNS cache, since a synchronous resolution is required in // order to create the alias. - session_deps_.host_resolver->Resolve( - info, &addresses, CompletionCallback(), NULL, BoundNetLog()); + session_deps_.host_resolver->Resolve(info, + DEFAULT_PRIORITY, + &addresses, + CompletionCallback(), + NULL, + BoundNetLog()); // Get a session for |key2|, which should return the session created earlier. base::WeakPtr<SpdySession> session2 = spdy_session_pool_->FindAvailableSession(key2, BoundNetLog()); @@ -2744,7 +2838,7 @@ TEST_P(SpdySessionTest, CloseOneIdleConnectionWithAlias) { TestCompletionCallback callback3; HostPortPair host_port3("3.com", 80); scoped_refptr<TransportSocketParams> params3( - new TransportSocketParams(host_port3, DEFAULT_PRIORITY, false, false, + new TransportSocketParams(host_port3, false, false, OnHostResolutionCallback())); scoped_ptr<ClientSocketHandle> connection3(new ClientSocketHandle); EXPECT_EQ(ERR_IO_PENDING, @@ -2760,9 +2854,9 @@ TEST_P(SpdySessionTest, CloseOneIdleConnectionWithAlias) { EXPECT_TRUE(session2 == NULL); } -// Tests that a non-SPDY request can't close a SPDY session that's currently in -// use. -TEST_P(SpdySessionTest, CloseOneIdleConnectionFailsWhenSessionInUse) { +// Tests that when a SPDY session becomes idle, it closes itself if there is +// a lower layer pool stalled on the per-pool socket limit. +TEST_P(SpdySessionTest, CloseSessionOnIdleWhenPoolStalled) { ClientSocketPoolManager::set_max_sockets_per_group( HttpNetworkSession::NORMAL_SOCKET_POOL, 1); ClientSocketPoolManager::set_max_sockets_per_pool( @@ -2781,10 +2875,19 @@ TEST_P(SpdySessionTest, CloseOneIdleConnectionFailsWhenSessionInUse) { CreateMockWrite(*cancel1, 1), }; StaticSocketDataProvider data(reads, arraysize(reads), - writes, arraysize(writes)); + writes, arraysize(writes)); data.set_connect_data(connect_data); session_deps_.socket_factory->AddSocketDataProvider(&data); + MockRead http_reads[] = { + MockRead(SYNCHRONOUS, ERR_IO_PENDING) // Stall forever. + }; + StaticSocketDataProvider http_data(http_reads, arraysize(http_reads), + NULL, 0); + http_data.set_connect_data(connect_data); + session_deps_.socket_factory->AddSocketDataProvider(&http_data); + + CreateNetworkSession(); TransportClientSocketPool* pool = @@ -2824,7 +2927,7 @@ TEST_P(SpdySessionTest, CloseOneIdleConnectionFailsWhenSessionInUse) { TestCompletionCallback callback2; HostPortPair host_port2("2.com", 80); scoped_refptr<TransportSocketParams> params2( - new TransportSocketParams(host_port2, DEFAULT_PRIORITY, false, false, + new TransportSocketParams(host_port2, false, false, OnHostResolutionCallback())); scoped_ptr<ClientSocketHandle> connection2(new ClientSocketHandle); EXPECT_EQ(ERR_IO_PENDING, @@ -2839,14 +2942,13 @@ TEST_P(SpdySessionTest, CloseOneIdleConnectionFailsWhenSessionInUse) { EXPECT_TRUE(pool->IsStalled()); EXPECT_FALSE(callback2.have_result()); - // Cancelling the request should still not release the session's socket, - // since the session is still kept alive by the SpdySessionPool. + // Cancelling the request should result in the session's socket being + // closed, since the pool is stalled. ASSERT_TRUE(spdy_stream1.get()); spdy_stream1->Cancel(); base::RunLoop().RunUntilIdle(); - EXPECT_TRUE(pool->IsStalled()); - EXPECT_FALSE(callback2.have_result()); - EXPECT_TRUE(session1 != NULL); + ASSERT_FALSE(pool->IsStalled()); + EXPECT_EQ(OK, callback2.WaitForResult()); } // Verify that SpdySessionKey and therefore SpdySession is different when diff --git a/chromium/net/spdy/spdy_stream.cc b/chromium/net/spdy/spdy_stream.cc index 18d20e8d3a4..a603a6c93e0 100644 --- a/chromium/net/spdy/spdy_stream.cc +++ b/chromium/net/spdy/spdy_stream.cc @@ -740,7 +740,7 @@ int SpdyStream::DoGetDomainBoundCert() { io_state_ = STATE_GET_DOMAIN_BOUND_CERT_COMPLETE; ServerBoundCertService* sbc_service = session_->GetServerBoundCertService(); DCHECK(sbc_service != NULL); - int rv = sbc_service->GetDomainBoundCert( + int rv = sbc_service->GetOrCreateDomainBoundCert( url.GetOrigin().host(), &domain_bound_private_key_, &domain_bound_cert_, diff --git a/chromium/net/spdy/spdy_test_util_common.cc b/chromium/net/spdy/spdy_test_util_common.cc index 4383db0860a..8fb085a3d8c 100644 --- a/chromium/net/spdy/spdy_test_util_common.cc +++ b/chromium/net/spdy/spdy_test_util_common.cc @@ -500,7 +500,7 @@ base::WeakPtr<SpdySession> CreateSpdySessionHelper( scoped_refptr<TransportSocketParams> transport_params( new TransportSocketParams( - key.host_port_pair(), MEDIUM, false, false, + key.host_port_pair(), false, false, OnHostResolutionCallback())); scoped_ptr<ClientSocketHandle> connection(new ClientSocketHandle); @@ -509,13 +509,10 @@ base::WeakPtr<SpdySession> CreateSpdySessionHelper( int rv = ERR_UNEXPECTED; if (is_secure) { SSLConfig ssl_config; - scoped_refptr<SOCKSSocketParams> socks_params; - scoped_refptr<HttpProxySocketParams> http_proxy_params; scoped_refptr<SSLSocketParams> ssl_params( new SSLSocketParams(transport_params, - socks_params, - http_proxy_params, - ProxyServer::SCHEME_DIRECT, + NULL, + NULL, key.host_port_pair(), ssl_config, key.privacy_mode(), @@ -649,8 +646,8 @@ base::WeakPtr<SpdySession> CreateFakeSpdySessionHelper( EXPECT_FALSE(HasSpdySession(pool, key)); base::WeakPtr<SpdySession> spdy_session; scoped_ptr<ClientSocketHandle> handle(new ClientSocketHandle()); - handle->set_socket(new FakeSpdySessionClientSocket( - expected_status == OK ? ERR_IO_PENDING : expected_status)); + handle->SetSocket(scoped_ptr<StreamSocket>(new FakeSpdySessionClientSocket( + expected_status == OK ? ERR_IO_PENDING : expected_status))); EXPECT_EQ( expected_status, pool->CreateAvailableSessionFromSocket( @@ -942,7 +939,8 @@ SpdyFrame* SpdyTestUtil::ConstructSpdyGet(const char* const extra_headers[], SpdyFrame* SpdyTestUtil::ConstructSpdyConnect( const char* const extra_headers[], int extra_header_count, - int stream_id) const { + int stream_id, + RequestPriority priority) const { const char* const kConnectHeaders[] = { GetMethodKey(), "CONNECT", GetPathKey(), "www.google.com:443", @@ -953,7 +951,7 @@ SpdyFrame* SpdyTestUtil::ConstructSpdyConnect( extra_header_count, /*compressed*/ false, stream_id, - LOWEST, + priority, SYN_STREAM, CONTROL_FLAG_NONE, kConnectHeaders, diff --git a/chromium/net/spdy/spdy_test_util_common.h b/chromium/net/spdy/spdy_test_util_common.h index aa833ff91f6..e9f453c13c5 100644 --- a/chromium/net/spdy/spdy_test_util_common.h +++ b/chromium/net/spdy/spdy_test_util_common.h @@ -421,7 +421,8 @@ class SpdyTestUtil { // Constructs a standard SPDY SYN_STREAM frame for a CONNECT request. SpdyFrame* ConstructSpdyConnect(const char* const extra_headers[], int extra_header_count, - int stream_id) const; + int stream_id, + RequestPriority priority) const; // Constructs a standard SPDY push SYN frame. // |extra_headers| are the extra header-value pairs, which typically diff --git a/chromium/net/spdy/write_blocked_list.h b/chromium/net/spdy/write_blocked_list.h new file mode 100644 index 00000000000..fe5668fae8a --- /dev/null +++ b/chromium/net/spdy/write_blocked_list.h @@ -0,0 +1,86 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_SPDY_WRITE_BLOCKED_LIST_H_ +#define NET_SPDY_WRITE_BLOCKED_LIST_H_ + +#include <algorithm> +#include <deque> + +#include "base/logging.h" + +namespace net { + +const int kHighestPriority = 0; +const int kLowestPriority = 7; + +template <typename IdType> +class WriteBlockedList { + public: + // 0(1) size lookup. 0(1) insert at front or back. + typedef std::deque<IdType> BlockedList; + typedef typename BlockedList::iterator iterator; + + // Returns the priority of the highest priority list with sessions on it, or + // -1 if none of the lists have pending sessions. + int GetHighestPriorityWriteBlockedList() const { + for (int i = 0; i <= kLowestPriority; ++i) { + if (write_blocked_lists_[i].size() > 0) + return i; + } + return -1; + } + + int PopFront(int priority) { + DCHECK(!write_blocked_lists_[priority].empty()); + IdType stream_id = write_blocked_lists_[priority].front(); + write_blocked_lists_[priority].pop_front(); + return stream_id; + } + + bool HasWriteBlockedStreamsGreaterThanPriority(int priority) const { + for (int i = kHighestPriority; i < priority; ++i) { + if (!write_blocked_lists_[i].empty()) { + return true; + } + } + return false; + } + + bool HasWriteBlockedStreams() const { + return HasWriteBlockedStreamsGreaterThanPriority(kLowestPriority + 1); + } + + void PushBack(IdType stream_id, int priority) { + write_blocked_lists_[priority].push_back(stream_id); + } + + void RemoveStreamFromWriteBlockedList(IdType stream_id, int priority) { + iterator it = std::find(write_blocked_lists_[priority].begin(), + write_blocked_lists_[priority].end(), + stream_id); + while (it != write_blocked_lists_[priority].end()) { + write_blocked_lists_[priority].erase(it); + it = std::find(write_blocked_lists_[priority].begin(), + write_blocked_lists_[priority].end(), + stream_id); + } + } + + int NumBlockedStreams() { + int num_blocked_streams = 0; + for (int i = kHighestPriority; i <= kLowestPriority; ++i) { + num_blocked_streams += write_blocked_lists_[i].size(); + } + return num_blocked_streams; + } + + private: + // Priority ranges from 0 to 7 + BlockedList write_blocked_lists_[8]; +}; + +} // namespace net + +#endif // NET_SPDY_WRITE_BLOCKED_LIST_H_ diff --git a/chromium/net/spdy/write_blocked_list_test.cc b/chromium/net/spdy/write_blocked_list_test.cc new file mode 100644 index 00000000000..3bb49df94f4 --- /dev/null +++ b/chromium/net/spdy/write_blocked_list_test.cc @@ -0,0 +1,76 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/spdy/write_blocked_list.h" + +#include "testing/gtest/include/gtest/gtest.h" + + +namespace net { +namespace test { +namespace { + +typedef WriteBlockedList<int> IntWriteBlockedList; + +TEST(WriteBlockedListTest, GetHighestPriority) { + IntWriteBlockedList list; + EXPECT_EQ(-1, list.GetHighestPriorityWriteBlockedList()); + list.PushBack(1, 1); + EXPECT_EQ(1, list.GetHighestPriorityWriteBlockedList()); + list.PushBack(1, 0); + EXPECT_EQ(0, list.GetHighestPriorityWriteBlockedList()); +} + +TEST(WriteBlockedListTest, HasWriteBlockedStreamsOfGreaterThanPriority) { + IntWriteBlockedList list; + list.PushBack(1, 4); + EXPECT_TRUE(list.HasWriteBlockedStreamsGreaterThanPriority(5)); + EXPECT_FALSE(list.HasWriteBlockedStreamsGreaterThanPriority(4)); + list.PushBack(1, 2); + EXPECT_TRUE(list.HasWriteBlockedStreamsGreaterThanPriority(3)); + EXPECT_FALSE(list.HasWriteBlockedStreamsGreaterThanPriority(2)); +} + +TEST(WriteBlockedListTest, RemoveStreamFromWriteBlockedList) { + IntWriteBlockedList list; + + list.PushBack(1, 4); + EXPECT_TRUE(list.HasWriteBlockedStreams()); + + list.RemoveStreamFromWriteBlockedList(1, 5); + EXPECT_TRUE(list.HasWriteBlockedStreams()); + + list.PushBack(2, 4); + list.PushBack(1, 4); + list.RemoveStreamFromWriteBlockedList(1, 4); + list.RemoveStreamFromWriteBlockedList(2, 4); + EXPECT_FALSE(list.HasWriteBlockedStreams()); + + list.PushBack(1, 7); + EXPECT_TRUE(list.HasWriteBlockedStreams()); +} + +TEST(WriteBlockedListTest, PopFront) { + IntWriteBlockedList list; + + list.PushBack(1, 4); + EXPECT_EQ(1, list.NumBlockedStreams()); + list.PushBack(2, 4); + list.PushBack(1, 4); + list.PushBack(3, 4); + EXPECT_EQ(4, list.NumBlockedStreams()); + + EXPECT_EQ(1, list.PopFront(4)); + EXPECT_EQ(2, list.PopFront(4)); + EXPECT_EQ(1, list.PopFront(4)); + EXPECT_EQ(1, list.NumBlockedStreams()); + EXPECT_EQ(3, list.PopFront(4)); +} + +} // namespace +} // namespace test +} // namespace net + + + diff --git a/chromium/net/ssl/server_bound_cert_service.cc b/chromium/net/ssl/server_bound_cert_service.cc index 4bc82ed5d9b..2bbcbc79e6b 100644 --- a/chromium/net/ssl/server_bound_cert_service.cc +++ b/chromium/net/ssl/server_bound_cert_service.cc @@ -43,7 +43,8 @@ const int kValidityPeriodInDays = 365; const int kSystemTimeValidityBufferInDays = 90; // Used by the GetDomainBoundCertResult histogram to record the final -// outcome of each GetDomainBoundCert call. Do not re-use values. +// outcome of each GetDomainBoundCert or GetOrCreateDomainBoundCert call. +// Do not re-use values. enum GetCertResult { // Synchronously found and returned an existing domain bound cert. SYNC_SUCCESS = 0, @@ -57,7 +58,8 @@ enum GetCertResult { ASYNC_FAILURE_CREATE_CERT = 4, ASYNC_FAILURE_EXPORT_KEY = 5, ASYNC_FAILURE_UNKNOWN = 6, - // GetDomainBoundCert was called with invalid arguments. + // GetDomainBoundCert or GetOrCreateDomainBoundCert was called with + // invalid arguments. INVALID_ARGUMENT = 7, // We don't support any of the cert types the server requested. UNSUPPORTED_TYPE = 8, @@ -277,14 +279,18 @@ class ServerBoundCertServiceWorker { // origin message loop. class ServerBoundCertServiceJob { public: - ServerBoundCertServiceJob() { } + ServerBoundCertServiceJob(bool create_if_missing) + : create_if_missing_(create_if_missing) { + } ~ServerBoundCertServiceJob() { if (!requests_.empty()) DeleteAllCanceled(); } - void AddRequest(ServerBoundCertServiceRequest* request) { + void AddRequest(ServerBoundCertServiceRequest* request, + bool create_if_missing = false) { + create_if_missing_ |= create_if_missing; requests_.push_back(request); } @@ -294,6 +300,8 @@ class ServerBoundCertServiceJob { PostAll(error, private_key, cert); } + bool CreateIfMissing() const { return create_if_missing_; } + private: void PostAll(int error, const std::string& private_key, @@ -320,6 +328,7 @@ class ServerBoundCertServiceJob { } std::vector<ServerBoundCertServiceRequest*> requests_; + bool create_if_missing_; }; // static @@ -388,7 +397,7 @@ std::string ServerBoundCertService::GetDomainForHost(const std::string& host) { return domain; } -int ServerBoundCertService::GetDomainBoundCert( +int ServerBoundCertService::GetOrCreateDomainBoundCert( const std::string& host, std::string* private_key, std::string* cert, @@ -411,49 +420,15 @@ int ServerBoundCertService::GetDomainBoundCert( requests_++; - // See if an identical request is currently in flight. - ServerBoundCertServiceJob* job = NULL; - std::map<std::string, ServerBoundCertServiceJob*>::const_iterator j; - j = inflight_.find(domain); - if (j != inflight_.end()) { - // An identical request is in flight already. We'll just attach our - // callback. - job = j->second; - inflight_joins_++; - - ServerBoundCertServiceRequest* request = new ServerBoundCertServiceRequest( - request_start, - base::Bind(&RequestHandle::OnRequestComplete, - base::Unretained(out_req)), - private_key, cert); - job->AddRequest(request); - out_req->RequestStarted(this, request, callback); + // See if a request for the same domain is currently in flight. + bool create_if_missing = true; + if (JoinToInFlightRequest(request_start, domain, private_key, cert, + create_if_missing, callback, out_req)) { return ERR_IO_PENDING; } - // Check if a domain bound cert of an acceptable type already exists for this - // domain. Note that |expiration_time| is ignored, and expired certs are - // considered valid. - base::Time expiration_time; - int err = server_bound_cert_store_->GetServerBoundCert( - domain, - &expiration_time /* ignored */, - private_key, - cert, - base::Bind(&ServerBoundCertService::GotServerBoundCert, - weak_ptr_factory_.GetWeakPtr())); - - if (err == OK) { - // Sync lookup found a valid cert. - DVLOG(1) << "Cert store had valid cert for " << domain; - cert_store_hits_++; - RecordGetDomainBoundCertResult(SYNC_SUCCESS); - base::TimeDelta request_time = base::TimeTicks::Now() - request_start; - UMA_HISTOGRAM_TIMES("DomainBoundCerts.GetCertTimeSync", request_time); - RecordGetCertTime(request_time); - return OK; - } - + int err = LookupDomainBoundCert(request_start, domain, private_key, cert, + create_if_missing, callback, out_req); if (err == ERR_FILE_NOT_FOUND) { // Sync lookup did not find a valid cert. Start generating a new one. workers_created_++; @@ -467,19 +442,17 @@ int ServerBoundCertService::GetDomainBoundCert( RecordGetDomainBoundCertResult(WORKER_FAILURE); return ERR_INSUFFICIENT_RESOURCES; } - } - - if (err == ERR_IO_PENDING || err == ERR_FILE_NOT_FOUND) { - // We are either waiting for async DB lookup, or waiting for cert - // generation. Create a job & request to track it. - job = new ServerBoundCertServiceJob(); + // We are waiting for cert generation. Create a job & request to track it. + ServerBoundCertServiceJob* job = + new ServerBoundCertServiceJob(create_if_missing); inflight_[domain] = job; ServerBoundCertServiceRequest* request = new ServerBoundCertServiceRequest( request_start, base::Bind(&RequestHandle::OnRequestComplete, base::Unretained(out_req)), - private_key, cert); + private_key, + cert); job->AddRequest(request); out_req->RequestStarted(this, request, callback); return ERR_IO_PENDING; @@ -488,6 +461,41 @@ int ServerBoundCertService::GetDomainBoundCert( return err; } +int ServerBoundCertService::GetDomainBoundCert( + const std::string& host, + std::string* private_key, + std::string* cert, + const CompletionCallback& callback, + RequestHandle* out_req) { + DVLOG(1) << __FUNCTION__ << " " << host; + DCHECK(CalledOnValidThread()); + base::TimeTicks request_start = base::TimeTicks::Now(); + + if (callback.is_null() || !private_key || !cert || host.empty()) { + RecordGetDomainBoundCertResult(INVALID_ARGUMENT); + return ERR_INVALID_ARGUMENT; + } + + std::string domain = GetDomainForHost(host); + if (domain.empty()) { + RecordGetDomainBoundCertResult(INVALID_ARGUMENT); + return ERR_INVALID_ARGUMENT; + } + + requests_++; + + // See if a request for the same domain currently in flight. + bool create_if_missing = false; + if (JoinToInFlightRequest(request_start, domain, private_key, cert, + create_if_missing, callback, out_req)) { + return ERR_IO_PENDING; + } + + int err = LookupDomainBoundCert(request_start, domain, private_key, cert, + create_if_missing, callback, out_req); + return err; +} + void ServerBoundCertService::GotServerBoundCert( int err, const std::string& server_identifier, @@ -511,7 +519,13 @@ void ServerBoundCertService::GotServerBoundCert( HandleResult(OK, server_identifier, key, cert); return; } - // Async lookup did not find a valid cert. Start generating a new one. + // Async lookup did not find a valid cert. If no request asked to create one, + // return the error directly. + if (!j->second->CreateIfMissing()) { + HandleResult(err, server_identifier, key, cert); + return; + } + // At least one request asked to create a cert => start generating a new one. workers_created_++; ServerBoundCertServiceWorker* worker = new ServerBoundCertServiceWorker( server_identifier, @@ -524,7 +538,6 @@ void ServerBoundCertService::GotServerBoundCert( server_identifier, std::string(), std::string()); - return; } } @@ -579,6 +592,86 @@ void ServerBoundCertService::HandleResult( delete job; } +bool ServerBoundCertService::JoinToInFlightRequest( + const base::TimeTicks& request_start, + const std::string& domain, + std::string* private_key, + std::string* cert, + bool create_if_missing, + const CompletionCallback& callback, + RequestHandle* out_req) { + ServerBoundCertServiceJob* job = NULL; + std::map<std::string, ServerBoundCertServiceJob*>::const_iterator j = + inflight_.find(domain); + if (j != inflight_.end()) { + // A request for the same domain is in flight already. We'll attach our + // callback, but we'll also mark it as requiring a cert if one's mising. + job = j->second; + inflight_joins_++; + + ServerBoundCertServiceRequest* request = new ServerBoundCertServiceRequest( + request_start, + base::Bind(&RequestHandle::OnRequestComplete, + base::Unretained(out_req)), + private_key, + cert); + job->AddRequest(request, create_if_missing); + out_req->RequestStarted(this, request, callback); + return true; + } + return false; +} + +int ServerBoundCertService::LookupDomainBoundCert( + const base::TimeTicks& request_start, + const std::string& domain, + std::string* private_key, + std::string* cert, + bool create_if_missing, + const CompletionCallback& callback, + RequestHandle* out_req) { + // Check if a domain bound cert already exists for this domain. Note that + // |expiration_time| is ignored, and expired certs are considered valid. + base::Time expiration_time; + int err = server_bound_cert_store_->GetServerBoundCert( + domain, + &expiration_time /* ignored */, + private_key, + cert, + base::Bind(&ServerBoundCertService::GotServerBoundCert, + weak_ptr_factory_.GetWeakPtr())); + + if (err == OK) { + // Sync lookup found a valid cert. + DVLOG(1) << "Cert store had valid cert for " << domain; + cert_store_hits_++; + RecordGetDomainBoundCertResult(SYNC_SUCCESS); + base::TimeDelta request_time = base::TimeTicks::Now() - request_start; + UMA_HISTOGRAM_TIMES("DomainBoundCerts.GetCertTimeSync", request_time); + RecordGetCertTime(request_time); + return OK; + } + + if (err == ERR_IO_PENDING) { + // We are waiting for async DB lookup. Create a job & request to track it. + ServerBoundCertServiceJob* job = + new ServerBoundCertServiceJob(create_if_missing); + inflight_[domain] = job; + + ServerBoundCertServiceRequest* request = new ServerBoundCertServiceRequest( + request_start, + base::Bind(&RequestHandle::OnRequestComplete, + base::Unretained(out_req)), + private_key, + cert); + job->AddRequest(request); + out_req->RequestStarted(this, request, callback); + return ERR_IO_PENDING; + } + + return err; +} + int ServerBoundCertService::cert_count() { return server_bound_cert_store_->GetCertCount(); } diff --git a/chromium/net/ssl/server_bound_cert_service.h b/chromium/net/ssl/server_bound_cert_service.h index d931ec87082..0dc7f4ae390 100644 --- a/chromium/net/ssl/server_bound_cert_service.h +++ b/chromium/net/ssl/server_bound_cert_service.h @@ -106,6 +106,31 @@ class NET_EXPORT ServerBoundCertService // |*out_req| will be initialized with a handle to the async request. This // RequestHandle object must be cancelled or destroyed before the // ServerBoundCertService is destroyed. + int GetOrCreateDomainBoundCert( + const std::string& host, + std::string* private_key, + std::string* cert, + const CompletionCallback& callback, + RequestHandle* out_req); + + // Fetches the domain bound cert for the specified host if one exists. + // Returns OK if successful, ERR_FILE_NOT_FOUND if none exists, or an error + // code upon failure. + // + // On successful completion, |private_key| stores a DER-encoded + // PrivateKeyInfo struct, and |cert| stores a DER-encoded certificate. + // The PrivateKeyInfo is always an ECDSA private key. + // + // |callback| must not be null. ERR_IO_PENDING is returned if the operation + // could not be completed immediately, in which case the result code will + // be passed to the callback when available. If an in-flight + // GetDomainBoundCert is pending, and a new GetOrCreateDomainBoundCert + // request arrives for the same domain, the GetDomainBoundCert request will + // not complete until a new cert is created. + // + // |*out_req| will be initialized with a handle to the async request. This + // RequestHandle object must be cancelled or destroyed before the + // ServerBoundCertService is destroyed. int GetDomainBoundCert( const std::string& host, std::string* private_key, @@ -143,6 +168,29 @@ class NET_EXPORT ServerBoundCertService const std::string& private_key, const std::string& cert); + // Searches for an in-flight request for the same domain. If found, + // attaches to the request and returns true. Returns false if no in-flight + // request is found. + bool JoinToInFlightRequest(const base::TimeTicks& request_start, + const std::string& domain, + std::string* private_key, + std::string* cert, + bool create_if_missing, + const CompletionCallback& callback, + RequestHandle* out_req); + + // Looks for the domain bound cert for |domain| in this service's store. + // Returns OK if it can be found synchronously, ERR_IO_PENDING if the + // result cannot be obtained synchronously, or a network error code on + // failure (including failure to find a domain-bound cert of |domain|). + int LookupDomainBoundCert(const base::TimeTicks& request_start, + const std::string& domain, + std::string* private_key, + std::string* cert, + bool create_if_missing, + const CompletionCallback& callback, + RequestHandle* out_req); + scoped_ptr<ServerBoundCertStore> server_bound_cert_store_; scoped_refptr<base::TaskRunner> task_runner_; diff --git a/chromium/net/ssl/server_bound_cert_service_unittest.cc b/chromium/net/ssl/server_bound_cert_service_unittest.cc index d7b8553b5ac..fc25b0b9322 100644 --- a/chromium/net/ssl/server_bound_cert_service_unittest.cc +++ b/chromium/net/ssl/server_bound_cert_service_unittest.cc @@ -10,8 +10,8 @@ #include "base/bind.h" #include "base/memory/scoped_ptr.h" #include "base/message_loop/message_loop.h" +#include "base/message_loop/message_loop_proxy.h" #include "base/task_runner.h" -#include "base/threading/sequenced_worker_pool.h" #include "crypto/ec_private_key.h" #include "net/base/net_errors.h" #include "net/base/test_completion_callback.h" @@ -28,23 +28,37 @@ void FailTest(int /* result */) { FAIL(); } +// Simple task runner that refuses to actually post any tasks. This simulates +// a TaskRunner that has been shutdown, by returning false for any attempt to +// add new tasks. +class FailingTaskRunner : public base::TaskRunner { + public: + FailingTaskRunner() {} + + virtual bool PostDelayedTask(const tracked_objects::Location& from_here, + const base::Closure& task, + base::TimeDelta delay) OVERRIDE { + return false; + } + + virtual bool RunsTasksOnCurrentThread() const OVERRIDE { return true; } + + protected: + virtual ~FailingTaskRunner() {} + + private: + DISALLOW_COPY_AND_ASSIGN(FailingTaskRunner); +}; + class ServerBoundCertServiceTest : public testing::Test { public: ServerBoundCertServiceTest() - : sequenced_worker_pool_(new base::SequencedWorkerPool( - 3, "ServerBoundCertServiceTest")), - service_(new ServerBoundCertService( + : service_(new ServerBoundCertService( new DefaultServerBoundCertStore(NULL), - sequenced_worker_pool_)) { - } - - virtual ~ServerBoundCertServiceTest() { - if (sequenced_worker_pool_.get()) - sequenced_worker_pool_->Shutdown(); + base::MessageLoopProxy::current())) { } protected: - scoped_refptr<base::SequencedWorkerPool> sequenced_worker_pool_; scoped_ptr<ServerBoundCertService> service_; }; @@ -136,6 +150,24 @@ TEST_F(ServerBoundCertServiceTest, GetDomainForHost) { // See http://crbug.com/91512 - implement OpenSSL version of CreateSelfSigned. #if !defined(USE_OPENSSL) +TEST_F(ServerBoundCertServiceTest, GetCacheMiss) { + std::string host("encrypted.google.com"); + + int error; + TestCompletionCallback callback; + ServerBoundCertService::RequestHandle request_handle; + + // Synchronous completion, because the store is initialized. + std::string private_key, der_cert; + EXPECT_EQ(0, service_->cert_count()); + error = service_->GetDomainBoundCert( + host, &private_key, &der_cert, callback.callback(), &request_handle); + EXPECT_EQ(ERR_FILE_NOT_FOUND, error); + EXPECT_FALSE(request_handle.is_active()); + EXPECT_EQ(0, service_->cert_count()); + EXPECT_TRUE(der_cert.empty()); +} + TEST_F(ServerBoundCertServiceTest, CacheHit) { std::string host("encrypted.google.com"); @@ -146,7 +178,7 @@ TEST_F(ServerBoundCertServiceTest, CacheHit) { // Asynchronous completion. std::string private_key_info1, der_cert1; EXPECT_EQ(0, service_->cert_count()); - error = service_->GetDomainBoundCert( + error = service_->GetOrCreateDomainBoundCert( host, &private_key_info1, &der_cert1, callback.callback(), &request_handle); EXPECT_EQ(ERR_IO_PENDING, error); @@ -160,7 +192,7 @@ TEST_F(ServerBoundCertServiceTest, CacheHit) { // Synchronous completion. std::string private_key_info2, der_cert2; - error = service_->GetDomainBoundCert( + error = service_->GetOrCreateDomainBoundCert( host, &private_key_info2, &der_cert2, callback.callback(), &request_handle); EXPECT_FALSE(request_handle.is_active()); @@ -169,8 +201,19 @@ TEST_F(ServerBoundCertServiceTest, CacheHit) { EXPECT_EQ(private_key_info1, private_key_info2); EXPECT_EQ(der_cert1, der_cert2); - EXPECT_EQ(2u, service_->requests()); - EXPECT_EQ(1u, service_->cert_store_hits()); + // Synchronous get. + std::string private_key_info3, der_cert3; + error = service_->GetDomainBoundCert( + host, &private_key_info3, &der_cert3, callback.callback(), + &request_handle); + EXPECT_FALSE(request_handle.is_active()); + EXPECT_EQ(OK, error); + EXPECT_EQ(1, service_->cert_count()); + EXPECT_EQ(der_cert1, der_cert3); + EXPECT_EQ(private_key_info1, private_key_info3); + + EXPECT_EQ(3u, service_->requests()); + EXPECT_EQ(2u, service_->cert_store_hits()); EXPECT_EQ(0u, service_->inflight_joins()); } @@ -182,7 +225,7 @@ TEST_F(ServerBoundCertServiceTest, StoreCerts) { std::string host1("encrypted.google.com"); std::string private_key_info1, der_cert1; EXPECT_EQ(0, service_->cert_count()); - error = service_->GetDomainBoundCert( + error = service_->GetOrCreateDomainBoundCert( host1, &private_key_info1, &der_cert1, callback.callback(), &request_handle); EXPECT_EQ(ERR_IO_PENDING, error); @@ -193,7 +236,7 @@ TEST_F(ServerBoundCertServiceTest, StoreCerts) { std::string host2("www.verisign.com"); std::string private_key_info2, der_cert2; - error = service_->GetDomainBoundCert( + error = service_->GetOrCreateDomainBoundCert( host2, &private_key_info2, &der_cert2, callback.callback(), &request_handle); EXPECT_EQ(ERR_IO_PENDING, error); @@ -204,7 +247,7 @@ TEST_F(ServerBoundCertServiceTest, StoreCerts) { std::string host3("www.twitter.com"); std::string private_key_info3, der_cert3; - error = service_->GetDomainBoundCert( + error = service_->GetOrCreateDomainBoundCert( host3, &private_key_info3, &der_cert3, callback.callback(), &request_handle); EXPECT_EQ(ERR_IO_PENDING, error); @@ -234,13 +277,13 @@ TEST_F(ServerBoundCertServiceTest, InflightJoin) { TestCompletionCallback callback2; ServerBoundCertService::RequestHandle request_handle2; - error = service_->GetDomainBoundCert( + error = service_->GetOrCreateDomainBoundCert( host, &private_key_info1, &der_cert1, callback1.callback(), &request_handle1); EXPECT_EQ(ERR_IO_PENDING, error); EXPECT_TRUE(request_handle1.is_active()); // Should join with the original request. - error = service_->GetDomainBoundCert( + error = service_->GetOrCreateDomainBoundCert( host, &private_key_info2, &der_cert2, callback2.callback(), &request_handle2); EXPECT_EQ(ERR_IO_PENDING, error); @@ -254,6 +297,45 @@ TEST_F(ServerBoundCertServiceTest, InflightJoin) { EXPECT_EQ(2u, service_->requests()); EXPECT_EQ(0u, service_->cert_store_hits()); EXPECT_EQ(1u, service_->inflight_joins()); + EXPECT_EQ(1u, service_->workers_created()); +} + +// Tests an inflight join of a Get request to a GetOrCreate request. +TEST_F(ServerBoundCertServiceTest, InflightJoinGetOrCreateAndGet) { + std::string host("encrypted.google.com"); + int error; + + std::string private_key_info1, der_cert1; + TestCompletionCallback callback1; + ServerBoundCertService::RequestHandle request_handle1; + + std::string private_key_info2; + std::string der_cert2; + TestCompletionCallback callback2; + ServerBoundCertService::RequestHandle request_handle2; + + error = service_->GetOrCreateDomainBoundCert( + host, &private_key_info1, &der_cert1, + callback1.callback(), &request_handle1); + EXPECT_EQ(ERR_IO_PENDING, error); + EXPECT_TRUE(request_handle1.is_active()); + // Should join with the original request. + error = service_->GetDomainBoundCert( + host, &private_key_info2, &der_cert2, callback2.callback(), + &request_handle2); + EXPECT_EQ(ERR_IO_PENDING, error); + EXPECT_TRUE(request_handle2.is_active()); + + error = callback1.WaitForResult(); + EXPECT_EQ(OK, error); + error = callback2.WaitForResult(); + EXPECT_EQ(OK, error); + EXPECT_EQ(der_cert1, der_cert2); + + EXPECT_EQ(2u, service_->requests()); + EXPECT_EQ(0u, service_->cert_store_hits()); + EXPECT_EQ(1u, service_->inflight_joins()); + EXPECT_EQ(1u, service_->workers_created()); } TEST_F(ServerBoundCertServiceTest, ExtractValuesFromBytesEC) { @@ -263,7 +345,7 @@ TEST_F(ServerBoundCertServiceTest, ExtractValuesFromBytesEC) { TestCompletionCallback callback; ServerBoundCertService::RequestHandle request_handle; - error = service_->GetDomainBoundCert( + error = service_->GetOrCreateDomainBoundCert( host, &private_key_info, &der_cert, callback.callback(), &request_handle); EXPECT_EQ(ERR_IO_PENDING, error); @@ -297,18 +379,16 @@ TEST_F(ServerBoundCertServiceTest, CancelRequest) { int error; ServerBoundCertService::RequestHandle request_handle; - error = service_->GetDomainBoundCert(host, - &private_key_info, - &der_cert, - base::Bind(&FailTest), - &request_handle); + error = service_->GetOrCreateDomainBoundCert(host, + &private_key_info, + &der_cert, + base::Bind(&FailTest), + &request_handle); EXPECT_EQ(ERR_IO_PENDING, error); EXPECT_TRUE(request_handle.is_active()); request_handle.Cancel(); EXPECT_FALSE(request_handle.is_active()); - // Wait for generation to finish. - sequenced_worker_pool_->FlushForTesting(); // Wait for reply from ServerBoundCertServiceWorker to be posted back to the // ServerBoundCertService. base::MessageLoop::current()->RunUntilIdle(); @@ -326,17 +406,15 @@ TEST_F(ServerBoundCertServiceTest, CancelRequestByHandleDestruction) { { ServerBoundCertService::RequestHandle request_handle; - error = service_->GetDomainBoundCert(host, - &private_key_info, - &der_cert, - base::Bind(&FailTest), - &request_handle); + error = service_->GetOrCreateDomainBoundCert(host, + &private_key_info, + &der_cert, + base::Bind(&FailTest), + &request_handle); EXPECT_EQ(ERR_IO_PENDING, error); EXPECT_TRUE(request_handle.is_active()); } - // Wait for generation to finish. - sequenced_worker_pool_->FlushForTesting(); // Wait for reply from ServerBoundCertServiceWorker to be posted back to the // ServerBoundCertService. base::MessageLoop::current()->RunUntilIdle(); @@ -352,11 +430,11 @@ TEST_F(ServerBoundCertServiceTest, DestructionWithPendingRequest) { int error; ServerBoundCertService::RequestHandle request_handle; - error = service_->GetDomainBoundCert(host, - &private_key_info, - &der_cert, - base::Bind(&FailTest), - &request_handle); + error = service_->GetOrCreateDomainBoundCert(host, + &private_key_info, + &der_cert, + base::Bind(&FailTest), + &request_handle); EXPECT_EQ(ERR_IO_PENDING, error); EXPECT_TRUE(request_handle.is_active()); @@ -364,10 +442,8 @@ TEST_F(ServerBoundCertServiceTest, DestructionWithPendingRequest) { request_handle.Cancel(); service_.reset(); - // Wait for generation to finish. - sequenced_worker_pool_->FlushForTesting(); // ServerBoundCertServiceWorker should not post anything back to the - // non-existant ServerBoundCertService, but run the loop just to be sure it + // non-existent ServerBoundCertService, but run the loop just to be sure it // doesn't. base::MessageLoop::current()->RunUntilIdle(); @@ -378,12 +454,9 @@ TEST_F(ServerBoundCertServiceTest, DestructionWithPendingRequest) { // requests gracefully fails. // This is a regression test for http://crbug.com/236387 TEST_F(ServerBoundCertServiceTest, RequestAfterPoolShutdown) { - // Shutdown the pool immediately. - sequenced_worker_pool_->Shutdown(); - sequenced_worker_pool_ = NULL; - - // Ensure any shutdown code is processed. - base::MessageLoop::current()->RunUntilIdle(); + scoped_refptr<FailingTaskRunner> task_runner(new FailingTaskRunner); + service_.reset(new ServerBoundCertService( + new DefaultServerBoundCertStore(NULL), task_runner)); // Make a request that will force synchronous completion. std::string host("encrypted.google.com"); @@ -391,11 +464,11 @@ TEST_F(ServerBoundCertServiceTest, RequestAfterPoolShutdown) { int error; ServerBoundCertService::RequestHandle request_handle; - error = service_->GetDomainBoundCert(host, - &private_key_info, - &der_cert, - base::Bind(&FailTest), - &request_handle); + error = service_->GetOrCreateDomainBoundCert(host, + &private_key_info, + &der_cert, + base::Bind(&FailTest), + &request_handle); // If we got here without crashing or a valgrind error, it worked. ASSERT_EQ(ERR_INSUFFICIENT_RESOURCES, error); EXPECT_FALSE(request_handle.is_active()); @@ -420,27 +493,27 @@ TEST_F(ServerBoundCertServiceTest, SimultaneousCreation) { TestCompletionCallback callback3; ServerBoundCertService::RequestHandle request_handle3; - error = service_->GetDomainBoundCert(host1, - &private_key_info1, - &der_cert1, - callback1.callback(), - &request_handle1); + error = service_->GetOrCreateDomainBoundCert(host1, + &private_key_info1, + &der_cert1, + callback1.callback(), + &request_handle1); EXPECT_EQ(ERR_IO_PENDING, error); EXPECT_TRUE(request_handle1.is_active()); - error = service_->GetDomainBoundCert(host2, - &private_key_info2, - &der_cert2, - callback2.callback(), - &request_handle2); + error = service_->GetOrCreateDomainBoundCert(host2, + &private_key_info2, + &der_cert2, + callback2.callback(), + &request_handle2); EXPECT_EQ(ERR_IO_PENDING, error); EXPECT_TRUE(request_handle2.is_active()); - error = service_->GetDomainBoundCert(host3, - &private_key_info3, - &der_cert3, - callback3.callback(), - &request_handle3); + error = service_->GetOrCreateDomainBoundCert(host3, + &private_key_info3, + &der_cert3, + callback3.callback(), + &request_handle3); EXPECT_EQ(ERR_IO_PENDING, error); EXPECT_TRUE(request_handle3.is_active()); @@ -492,7 +565,7 @@ TEST_F(ServerBoundCertServiceTest, Expiration) { // Cert is valid - synchronous completion. std::string private_key_info1, der_cert1; - error = service_->GetDomainBoundCert( + error = service_->GetOrCreateDomainBoundCert( "good", &private_key_info1, &der_cert1, callback.callback(), &request_handle); EXPECT_EQ(OK, error); @@ -503,7 +576,7 @@ TEST_F(ServerBoundCertServiceTest, Expiration) { // Expired cert is valid as well - synchronous completion. std::string private_key_info2, der_cert2; - error = service_->GetDomainBoundCert( + error = service_->GetOrCreateDomainBoundCert( "expired", &private_key_info2, &der_cert2, callback.callback(), &request_handle); EXPECT_EQ(OK, error); @@ -513,11 +586,11 @@ TEST_F(ServerBoundCertServiceTest, Expiration) { EXPECT_STREQ("d", der_cert2.c_str()); } -TEST_F(ServerBoundCertServiceTest, AsyncStoreGetNoCertsInStore) { +TEST_F(ServerBoundCertServiceTest, AsyncStoreGetOrCreateNoCertsInStore) { MockServerBoundCertStoreWithAsyncGet* mock_store = new MockServerBoundCertStoreWithAsyncGet(); - service_ = scoped_ptr<ServerBoundCertService>( - new ServerBoundCertService(mock_store, sequenced_worker_pool_)); + service_ = scoped_ptr<ServerBoundCertService>(new ServerBoundCertService( + mock_store, base::MessageLoopProxy::current())); std::string host("encrypted.google.com"); @@ -528,7 +601,7 @@ TEST_F(ServerBoundCertServiceTest, AsyncStoreGetNoCertsInStore) { // Asynchronous completion with no certs in the store. std::string private_key_info, der_cert; EXPECT_EQ(0, service_->cert_count()); - error = service_->GetDomainBoundCert( + error = service_->GetOrCreateDomainBoundCert( host, &private_key_info, &der_cert, callback.callback(), &request_handle); EXPECT_EQ(ERR_IO_PENDING, error); EXPECT_TRUE(request_handle.is_active()); @@ -544,11 +617,42 @@ TEST_F(ServerBoundCertServiceTest, AsyncStoreGetNoCertsInStore) { EXPECT_FALSE(request_handle.is_active()); } -TEST_F(ServerBoundCertServiceTest, AsyncStoreGetOneCertInStore) { +TEST_F(ServerBoundCertServiceTest, AsyncStoreGetNoCertsInStore) { MockServerBoundCertStoreWithAsyncGet* mock_store = new MockServerBoundCertStoreWithAsyncGet(); - service_ = scoped_ptr<ServerBoundCertService>( - new ServerBoundCertService(mock_store, sequenced_worker_pool_)); + service_ = scoped_ptr<ServerBoundCertService>(new ServerBoundCertService( + mock_store, base::MessageLoopProxy::current())); + + std::string host("encrypted.google.com"); + + int error; + TestCompletionCallback callback; + ServerBoundCertService::RequestHandle request_handle; + + // Asynchronous completion with no certs in the store. + std::string private_key, der_cert; + EXPECT_EQ(0, service_->cert_count()); + error = service_->GetDomainBoundCert( + host, &private_key, &der_cert, callback.callback(), &request_handle); + EXPECT_EQ(ERR_IO_PENDING, error); + EXPECT_TRUE(request_handle.is_active()); + + mock_store->CallGetServerBoundCertCallbackWithResult( + ERR_FILE_NOT_FOUND, base::Time(), std::string(), std::string()); + + error = callback.WaitForResult(); + EXPECT_EQ(ERR_FILE_NOT_FOUND, error); + EXPECT_EQ(0, service_->cert_count()); + EXPECT_EQ(0u, service_->workers_created()); + EXPECT_TRUE(der_cert.empty()); + EXPECT_FALSE(request_handle.is_active()); +} + +TEST_F(ServerBoundCertServiceTest, AsyncStoreGetOrCreateOneCertInStore) { + MockServerBoundCertStoreWithAsyncGet* mock_store = + new MockServerBoundCertStoreWithAsyncGet(); + service_ = scoped_ptr<ServerBoundCertService>(new ServerBoundCertService( + mock_store, base::MessageLoopProxy::current())); std::string host("encrypted.google.com"); @@ -559,7 +663,7 @@ TEST_F(ServerBoundCertServiceTest, AsyncStoreGetOneCertInStore) { // Asynchronous completion with a cert in the store. std::string private_key_info, der_cert; EXPECT_EQ(0, service_->cert_count()); - error = service_->GetDomainBoundCert( + error = service_->GetOrCreateDomainBoundCert( host, &private_key_info, &der_cert, callback.callback(), &request_handle); EXPECT_EQ(ERR_IO_PENDING, error); EXPECT_TRUE(request_handle.is_active()); @@ -580,6 +684,94 @@ TEST_F(ServerBoundCertServiceTest, AsyncStoreGetOneCertInStore) { EXPECT_FALSE(request_handle.is_active()); } +TEST_F(ServerBoundCertServiceTest, AsyncStoreGetOneCertInStore) { + MockServerBoundCertStoreWithAsyncGet* mock_store = + new MockServerBoundCertStoreWithAsyncGet(); + service_ = scoped_ptr<ServerBoundCertService>(new ServerBoundCertService( + mock_store, base::MessageLoopProxy::current())); + + std::string host("encrypted.google.com"); + + int error; + TestCompletionCallback callback; + ServerBoundCertService::RequestHandle request_handle; + + // Asynchronous completion with a cert in the store. + std::string private_key, der_cert; + EXPECT_EQ(0, service_->cert_count()); + error = service_->GetDomainBoundCert( + host, &private_key, &der_cert, callback.callback(), &request_handle); + EXPECT_EQ(ERR_IO_PENDING, error); + EXPECT_TRUE(request_handle.is_active()); + + mock_store->CallGetServerBoundCertCallbackWithResult( + OK, base::Time(), "ab", "cd"); + + error = callback.WaitForResult(); + EXPECT_EQ(OK, error); + EXPECT_EQ(1, service_->cert_count()); + EXPECT_EQ(1u, service_->requests()); + EXPECT_EQ(1u, service_->cert_store_hits()); + // Because the cert was found in the store, no new workers should have been + // created. + EXPECT_EQ(0u, service_->workers_created()); + EXPECT_STREQ("cd", der_cert.c_str()); + EXPECT_FALSE(request_handle.is_active()); +} + +TEST_F(ServerBoundCertServiceTest, AsyncStoreGetThenCreateNoCertsInStore) { + MockServerBoundCertStoreWithAsyncGet* mock_store = + new MockServerBoundCertStoreWithAsyncGet(); + service_ = scoped_ptr<ServerBoundCertService>(new ServerBoundCertService( + mock_store, base::MessageLoopProxy::current())); + + std::string host("encrypted.google.com"); + + int error; + + // Asynchronous get with no certs in the store. + TestCompletionCallback callback1; + ServerBoundCertService::RequestHandle request_handle1; + std::string private_key1, der_cert1; + EXPECT_EQ(0, service_->cert_count()); + error = service_->GetDomainBoundCert( + host, &private_key1, &der_cert1, callback1.callback(), &request_handle1); + EXPECT_EQ(ERR_IO_PENDING, error); + EXPECT_TRUE(request_handle1.is_active()); + + // Asynchronous get/create with no certs in the store. + TestCompletionCallback callback2; + ServerBoundCertService::RequestHandle request_handle2; + std::string private_key2, der_cert2; + EXPECT_EQ(0, service_->cert_count()); + error = service_->GetOrCreateDomainBoundCert( + host, &private_key2, &der_cert2, callback2.callback(), &request_handle2); + EXPECT_EQ(ERR_IO_PENDING, error); + EXPECT_TRUE(request_handle2.is_active()); + + mock_store->CallGetServerBoundCertCallbackWithResult( + ERR_FILE_NOT_FOUND, base::Time(), std::string(), std::string()); + + // Even though the first request didn't ask to create a cert, it gets joined + // by the second, which does, so both succeed. + error = callback1.WaitForResult(); + EXPECT_EQ(OK, error); + error = callback2.WaitForResult(); + EXPECT_EQ(OK, error); + + // One cert is created, one request is joined. + EXPECT_EQ(2U, service_->requests()); + EXPECT_EQ(1, service_->cert_count()); + EXPECT_EQ(1u, service_->workers_created()); + EXPECT_EQ(1u, service_->inflight_joins()); + EXPECT_FALSE(der_cert1.empty()); + EXPECT_EQ(der_cert1, der_cert2); + EXPECT_FALSE(private_key1.empty()); + EXPECT_EQ(private_key1, private_key2); + EXPECT_FALSE(request_handle1.is_active()); + EXPECT_FALSE(request_handle2.is_active()); +} + #endif // !defined(USE_OPENSSL) } // namespace diff --git a/chromium/net/ssl/ssl_cipher_suite_names.cc b/chromium/net/ssl/ssl_cipher_suite_names.cc index f12d017fe76..f9394dfd68c 100644 --- a/chromium/net/ssl/ssl_cipher_suite_names.cc +++ b/chromium/net/ssl/ssl_cipher_suite_names.cc @@ -194,6 +194,8 @@ static const struct CipherSuite kCipherSuites[] = { {0xc08b, 0x1087}, // TLS_ECDHE_RSA_WITH_CAMELLIA_256_GCM_SHA384 {0xc08c, 0xf7f}, // TLS_ECDH_RSA_WITH_CAMELLIA_128_GCM_SHA256 {0xc08d, 0xf87}, // TLS_ECDH_RSA_WITH_CAMELLIA_256_GCM_SHA384 + {0xcc13, 0x108f}, // TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305 + {0xcc14, 0x0d8f}, // TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305 }; static const struct { @@ -220,8 +222,8 @@ static const struct { }; static const struct { - char name[17]; -} kCipherNames[17] = { + char name[18]; +} kCipherNames[18] = { {"NULL"}, // 0 {"RC4_40"}, // 1 {"RC4_128"}, // 2 @@ -239,6 +241,7 @@ static const struct { {"AES_256_GCM"}, // 14 {"CAMELLIA_128_GCM"}, // 15 {"CAMELLIA_256_GCM"}, // 16 + {"CHACHA20_POLY1305"}, // 17 }; static const struct { diff --git a/chromium/net/ssl/ssl_config_service.cc b/chromium/net/ssl/ssl_config_service.cc index 265b43c10e5..a2c34a26852 100644 --- a/chromium/net/ssl/ssl_config_service.cc +++ b/chromium/net/ssl/ssl_config_service.cc @@ -42,7 +42,7 @@ SSLConfig::SSLConfig() version_max(g_default_version_max), cached_info_enabled(false), channel_id_enabled(true), - false_start_enabled(false), + false_start_enabled(true), unrestricted_ssl3_fallback_enabled(false), send_client_cert(false), verify_ev_cert(false), diff --git a/chromium/net/test/android/javatests/src/org/chromium/net/test/util/TestWebServer.java b/chromium/net/test/android/javatests/src/org/chromium/net/test/util/TestWebServer.java deleted file mode 100644 index 9e60a43fa41..00000000000 --- a/chromium/net/test/android/javatests/src/org/chromium/net/test/util/TestWebServer.java +++ /dev/null @@ -1,557 +0,0 @@ -// Copyright (c) 2012 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -package org.chromium.net.test.util; - -import android.util.Base64; -import android.util.Log; -import android.util.Pair; - -import org.apache.http.HttpException; -import org.apache.http.HttpRequest; -import org.apache.http.HttpResponse; -import org.apache.http.HttpStatus; -import org.apache.http.HttpVersion; -import org.apache.http.RequestLine; -import org.apache.http.StatusLine; -import org.apache.http.entity.ByteArrayEntity; -import org.apache.http.impl.DefaultHttpServerConnection; -import org.apache.http.impl.cookie.DateUtils; -import org.apache.http.message.BasicHttpResponse; -import org.apache.http.params.BasicHttpParams; -import org.apache.http.params.CoreProtocolPNames; -import org.apache.http.params.HttpParams; - -import java.io.ByteArrayInputStream; -import java.io.IOException; -import java.io.InputStream; -import java.net.MalformedURLException; -import java.net.ServerSocket; -import java.net.Socket; -import java.net.URI; -import java.net.URL; -import java.net.URLConnection; -import java.security.KeyManagementException; -import java.security.KeyStore; -import java.security.NoSuchAlgorithmException; -import java.security.cert.X509Certificate; -import java.util.ArrayList; -import java.util.Date; -import java.util.HashMap; -import java.util.Hashtable; -import java.util.List; -import java.util.Map; - -import javax.net.ssl.HostnameVerifier; -import javax.net.ssl.HttpsURLConnection; -import javax.net.ssl.KeyManager; -import javax.net.ssl.KeyManagerFactory; -import javax.net.ssl.SSLContext; -import javax.net.ssl.SSLSession; -import javax.net.ssl.X509TrustManager; - -/** - * Simple http test server for testing. - * - * This server runs in a thread in the current process, so it is convenient - * for loopback testing without the need to setup tcp forwarding to the - * host computer. - * - * Based heavily on the CTSWebServer in Android. - */ -public class TestWebServer { - private static final String TAG = "TestWebServer"; - private static final int SERVER_PORT = 4444; - private static final int SSL_SERVER_PORT = 4445; - - public static final String SHUTDOWN_PREFIX = "/shutdown"; - - private static TestWebServer sInstance; - private static Hashtable<Integer, String> sReasons; - - private final ServerThread mServerThread; - private String mServerUri; - private final boolean mSsl; - - private static class Response { - final byte[] mResponseData; - final List<Pair<String, String>> mResponseHeaders; - final boolean mIsRedirect; - - Response(byte[] resposneData, List<Pair<String, String>> responseHeaders, - boolean isRedirect) { - mIsRedirect = isRedirect; - mResponseData = resposneData; - mResponseHeaders = responseHeaders == null ? - new ArrayList<Pair<String, String>>() : responseHeaders; - } - } - - // The Maps below are modified on both the client thread and the internal server thread, so - // need to use a lock when accessing them. - private final Object mLock = new Object(); - private final Map<String, Response> mResponseMap = new HashMap<String, Response>(); - private final Map<String, Integer> mResponseCountMap = new HashMap<String, Integer>(); - private final Map<String, HttpRequest> mLastRequestMap = new HashMap<String, HttpRequest>(); - - /** - * Create and start a local HTTP server instance. - * @param ssl True if the server should be using secure sockets. - * @throws Exception - */ - public TestWebServer(boolean ssl) throws Exception { - if (sInstance != null) { - // attempt to start a new instance while one is still running - // shut down the old instance first - sInstance.shutdown(); - } - setStaticInstance(this); - mSsl = ssl; - if (mSsl) { - mServerUri = "https://localhost:" + SSL_SERVER_PORT; - } else { - mServerUri = "http://localhost:" + SERVER_PORT; - } - mServerThread = new ServerThread(this, mSsl); - mServerThread.start(); - } - - private static void setStaticInstance(TestWebServer instance) { - sInstance = instance; - } - - /** - * Terminate the http server. - */ - public void shutdown() { - try { - // Avoid a deadlock between two threads where one is trying to call - // close() and the other one is calling accept() by sending a GET - // request for shutdown and having the server's one thread - // sequentially call accept() and close(). - URL url = new URL(mServerUri + SHUTDOWN_PREFIX); - URLConnection connection = openConnection(url); - connection.connect(); - - // Read the input from the stream to send the request. - InputStream is = connection.getInputStream(); - is.close(); - - // Block until the server thread is done shutting down. - mServerThread.join(); - - } catch (MalformedURLException e) { - throw new IllegalStateException(e); - } catch (InterruptedException e) { - throw new RuntimeException(e); - } catch (IOException e) { - throw new RuntimeException(e); - } catch (NoSuchAlgorithmException e) { - throw new IllegalStateException(e); - } catch (KeyManagementException e) { - throw new IllegalStateException(e); - } - - setStaticInstance(null); - } - - private final static int RESPONSE_STATUS_NORMAL = 0; - private final static int RESPONSE_STATUS_MOVED_TEMPORARILY = 1; - - private String setResponseInternal( - String requestPath, byte[] responseData, - List<Pair<String, String>> responseHeaders, - int status) { - final boolean isRedirect = (status == RESPONSE_STATUS_MOVED_TEMPORARILY); - - synchronized (mLock) { - mResponseMap.put(requestPath, new Response(responseData, responseHeaders, isRedirect)); - mResponseCountMap.put(requestPath, Integer.valueOf(0)); - mLastRequestMap.put(requestPath, null); - } - return getResponseUrl(requestPath); - } - - /** - * Gets the URL on the server under which a particular request path will be accessible. - * - * This only gets the URL, you still need to set the response if you intend to access it. - * - * @param requestPath The path to respond to. - * @return The full URL including the requestPath. - */ - public String getResponseUrl(String requestPath) { - return mServerUri + requestPath; - } - - /** - * Sets a response to be returned when a particular request path is passed - * in (with the option to specify additional headers). - * - * @param requestPath The path to respond to. - * @param responseString The response body that will be returned. - * @param responseHeaders Any additional headers that should be returned along with the - * response (null is acceptable). - * @return The full URL including the path that should be requested to get the expected - * response. - */ - public String setResponse( - String requestPath, String responseString, - List<Pair<String, String>> responseHeaders) { - return setResponseInternal(requestPath, responseString.getBytes(), responseHeaders, - RESPONSE_STATUS_NORMAL); - } - - /** - * Sets a redirect. - * - * @param requestPath The path to respond to. - * @param targetPath The path to redirect to. - * @return The full URL including the path that should be requested to get the expected - * response. - */ - public String setRedirect( - String requestPath, String targetPath) { - List<Pair<String, String>> responseHeaders = new ArrayList<Pair<String, String>>(); - responseHeaders.add(Pair.create("Location", targetPath)); - - return setResponseInternal(requestPath, targetPath.getBytes(), responseHeaders, - RESPONSE_STATUS_MOVED_TEMPORARILY); - } - - /** - * Sets a base64 encoded response to be returned when a particular request path is passed - * in (with the option to specify additional headers). - * - * @param requestPath The path to respond to. - * @param base64EncodedResponse The response body that is base64 encoded. The actual server - * response will the decoded binary form. - * @param responseHeaders Any additional headers that should be returned along with the - * response (null is acceptable). - * @return The full URL including the path that should be requested to get the expected - * response. - */ - public String setResponseBase64( - String requestPath, String base64EncodedResponse, - List<Pair<String, String>> responseHeaders) { - return setResponseInternal(requestPath, - Base64.decode(base64EncodedResponse, Base64.DEFAULT), - responseHeaders, - RESPONSE_STATUS_NORMAL); - } - - /** - * Get the number of requests was made at this path since it was last set. - */ - public int getRequestCount(String requestPath) { - Integer count = null; - synchronized (mLock) { - count = mResponseCountMap.get(requestPath); - } - if (count == null) throw new IllegalArgumentException("Path not set: " + requestPath); - return count.intValue(); - } - - /** - * Returns the last HttpRequest at this path. Can return null if it is never requested. - */ - public HttpRequest getLastRequest(String requestPath) { - synchronized (mLock) { - if (!mLastRequestMap.containsKey(requestPath)) - throw new IllegalArgumentException("Path not set: " + requestPath); - return mLastRequestMap.get(requestPath); - } - } - - public String getBaseUrl() { - return mServerUri + "/"; - } - - private URLConnection openConnection(URL url) - throws IOException, NoSuchAlgorithmException, KeyManagementException { - if (mSsl) { - // Install hostname verifiers and trust managers that don't do - // anything in order to get around the client not trusting - // the test server due to a lack of certificates. - - HttpsURLConnection connection = (HttpsURLConnection) url.openConnection(); - connection.setHostnameVerifier(new TestHostnameVerifier()); - - SSLContext context = SSLContext.getInstance("TLS"); - TestTrustManager trustManager = new TestTrustManager(); - context.init(null, new TestTrustManager[] {trustManager}, null); - connection.setSSLSocketFactory(context.getSocketFactory()); - - return connection; - } else { - return url.openConnection(); - } - } - - /** - * {@link X509TrustManager} that trusts everybody. This is used so that - * the client calling {@link TestWebServer#shutdown()} can issue a request - * for shutdown by blindly trusting the {@link TestWebServer}'s - * credentials. - */ - private static class TestTrustManager implements X509TrustManager { - @Override - public void checkClientTrusted(X509Certificate[] chain, String authType) { - // Trust the TestWebServer... - } - - @Override - public void checkServerTrusted(X509Certificate[] chain, String authType) { - // Trust the TestWebServer... - } - - @Override - public X509Certificate[] getAcceptedIssuers() { - return null; - } - } - - /** - * {@link HostnameVerifier} that verifies everybody. This permits - * the client to trust the web server and call - * {@link TestWebServer#shutdown()}. - */ - private static class TestHostnameVerifier implements HostnameVerifier { - @Override - public boolean verify(String hostname, SSLSession session) { - return true; - } - } - - private void servedResponseFor(String path, HttpRequest request) { - synchronized (mLock) { - mResponseCountMap.put(path, Integer.valueOf( - mResponseCountMap.get(path).intValue() + 1)); - mLastRequestMap.put(path, request); - } - } - - /** - * Generate a response to the given request. - * @throws InterruptedException - */ - private HttpResponse getResponse(HttpRequest request) throws InterruptedException { - RequestLine requestLine = request.getRequestLine(); - HttpResponse httpResponse = null; - Log.i(TAG, requestLine.getMethod() + ": " + requestLine.getUri()); - String uriString = requestLine.getUri(); - URI uri = URI.create(uriString); - String path = uri.getPath(); - - Response response = null; - synchronized (mLock) { - response = mResponseMap.get(path); - } - if (path.equals(SHUTDOWN_PREFIX)) { - httpResponse = createResponse(HttpStatus.SC_OK); - } else if (response == null) { - httpResponse = createResponse(HttpStatus.SC_NOT_FOUND); - } else if (response.mIsRedirect) { - httpResponse = createResponse(HttpStatus.SC_MOVED_TEMPORARILY); - for (Pair<String, String> header : response.mResponseHeaders) { - httpResponse.addHeader(header.first, header.second); - } - servedResponseFor(path, request); - } else { - httpResponse = createResponse(HttpStatus.SC_OK); - httpResponse.setEntity(createEntity(response.mResponseData)); - for (Pair<String, String> header : response.mResponseHeaders) { - httpResponse.addHeader(header.first, header.second); - } - servedResponseFor(path, request); - } - StatusLine sl = httpResponse.getStatusLine(); - Log.i(TAG, sl.getStatusCode() + "(" + sl.getReasonPhrase() + ")"); - setDateHeaders(httpResponse); - return httpResponse; - } - - private void setDateHeaders(HttpResponse response) { - response.addHeader("Date", DateUtils.formatDate(new Date(), DateUtils.PATTERN_RFC1123)); - } - - /** - * Create an empty response with the given status. - */ - private HttpResponse createResponse(int status) { - HttpResponse response = new BasicHttpResponse(HttpVersion.HTTP_1_0, status, null); - String reason = null; - - // This synchronized silences findbugs. - synchronized (TestWebServer.class) { - if (sReasons == null) { - sReasons = new Hashtable<Integer, String>(); - sReasons.put(HttpStatus.SC_UNAUTHORIZED, "Unauthorized"); - sReasons.put(HttpStatus.SC_NOT_FOUND, "Not Found"); - sReasons.put(HttpStatus.SC_FORBIDDEN, "Forbidden"); - sReasons.put(HttpStatus.SC_MOVED_TEMPORARILY, "Moved Temporarily"); - } - // Fill in error reason. Avoid use of the ReasonPhraseCatalog, which is - // Locale-dependent. - reason = sReasons.get(status); - } - - if (reason != null) { - StringBuffer buf = new StringBuffer("<html><head><title>"); - buf.append(reason); - buf.append("</title></head><body>"); - buf.append(reason); - buf.append("</body></html>"); - response.setEntity(createEntity(buf.toString().getBytes())); - } - return response; - } - - /** - * Create a string entity for the given content. - */ - private ByteArrayEntity createEntity(byte[] data) { - ByteArrayEntity entity = new ByteArrayEntity(data); - entity.setContentType("text/html"); - return entity; - } - - private static class ServerThread extends Thread { - private TestWebServer mServer; - private ServerSocket mSocket; - private boolean mIsSsl; - private boolean mIsCancelled; - private SSLContext mSslContext; - - /** - * Defines the keystore contents for the server, BKS version. Holds just a - * single self-generated key. The subject name is "Test Server". - */ - private static final String SERVER_KEYS_BKS = - "AAAAAQAAABQDkebzoP1XwqyWKRCJEpn/t8dqIQAABDkEAAVteWtleQAAARpYl20nAAAAAQAFWC41" + - "MDkAAAJNMIICSTCCAbKgAwIBAgIESEfU1jANBgkqhkiG9w0BAQUFADBpMQswCQYDVQQGEwJVUzET" + - "MBEGA1UECBMKQ2FsaWZvcm5pYTEMMAoGA1UEBxMDTVRWMQ8wDQYDVQQKEwZHb29nbGUxEDAOBgNV" + - "BAsTB0FuZHJvaWQxFDASBgNVBAMTC1Rlc3QgU2VydmVyMB4XDTA4MDYwNTExNTgxNFoXDTA4MDkw" + - "MzExNTgxNFowaTELMAkGA1UEBhMCVVMxEzARBgNVBAgTCkNhbGlmb3JuaWExDDAKBgNVBAcTA01U" + - "VjEPMA0GA1UEChMGR29vZ2xlMRAwDgYDVQQLEwdBbmRyb2lkMRQwEgYDVQQDEwtUZXN0IFNlcnZl" + - "cjCBnzANBgkqhkiG9w0BAQEFAAOBjQAwgYkCgYEA0LIdKaIr9/vsTq8BZlA3R+NFWRaH4lGsTAQy" + - "DPMF9ZqEDOaL6DJuu0colSBBBQ85hQTPa9m9nyJoN3pEi1hgamqOvQIWcXBk+SOpUGRZZFXwniJV" + - "zDKU5nE9MYgn2B9AoiH3CSuMz6HRqgVaqtppIe1jhukMc/kHVJvlKRNy9XMCAwEAATANBgkqhkiG" + - "9w0BAQUFAAOBgQC7yBmJ9O/eWDGtSH9BH0R3dh2NdST3W9hNZ8hIa8U8klhNHbUCSSktZmZkvbPU" + - "hse5LI3dh6RyNDuqDrbYwcqzKbFJaq/jX9kCoeb3vgbQElMRX8D2ID1vRjxwlALFISrtaN4VpWzV" + - "yeoHPW4xldeZmoVtjn8zXNzQhLuBqX2MmAAAAqwAAAAUvkUScfw9yCSmALruURNmtBai7kQAAAZx" + - "4Jmijxs/l8EBaleaUru6EOPioWkUAEVWCxjM/TxbGHOi2VMsQWqRr/DZ3wsDmtQgw3QTrUK666sR" + - "MBnbqdnyCyvM1J2V1xxLXPUeRBmR2CXorYGF9Dye7NkgVdfA+9g9L/0Au6Ugn+2Cj5leoIgkgApN" + - "vuEcZegFlNOUPVEs3SlBgUF1BY6OBM0UBHTPwGGxFBBcetcuMRbUnu65vyDG0pslT59qpaR0TMVs" + - "P+tcheEzhyjbfM32/vwhnL9dBEgM8qMt0sqF6itNOQU/F4WGkK2Cm2v4CYEyKYw325fEhzTXosck" + - "MhbqmcyLab8EPceWF3dweoUT76+jEZx8lV2dapR+CmczQI43tV9btsd1xiBbBHAKvymm9Ep9bPzM" + - "J0MQi+OtURL9Lxke/70/MRueqbPeUlOaGvANTmXQD2OnW7PISwJ9lpeLfTG0LcqkoqkbtLKQLYHI" + - "rQfV5j0j+wmvmpMxzjN3uvNajLa4zQ8l0Eok9SFaRr2RL0gN8Q2JegfOL4pUiHPsh64WWya2NB7f" + - "V+1s65eA5ospXYsShRjo046QhGTmymwXXzdzuxu8IlnTEont6P4+J+GsWk6cldGbl20hctuUKzyx" + - "OptjEPOKejV60iDCYGmHbCWAzQ8h5MILV82IclzNViZmzAapeeCnexhpXhWTs+xDEYSKEiG/camt" + - "bhmZc3BcyVJrW23PktSfpBQ6D8ZxoMfF0L7V2GQMaUg+3r7ucrx82kpqotjv0xHghNIm95aBr1Qw" + - "1gaEjsC/0wGmmBDg1dTDH+F1p9TInzr3EFuYD0YiQ7YlAHq3cPuyGoLXJ5dXYuSBfhDXJSeddUkl" + - "k1ufZyOOcskeInQge7jzaRfmKg3U94r+spMEvb0AzDQVOKvjjo1ivxMSgFRZaDb/4qw="; - - private static final String PASSWORD = "android"; - - /** - * Loads a keystore from a base64-encoded String. Returns the KeyManager[] - * for the result. - */ - private KeyManager[] getKeyManagers() throws Exception { - byte[] bytes = Base64.decode(SERVER_KEYS_BKS, Base64.DEFAULT); - InputStream inputStream = new ByteArrayInputStream(bytes); - - KeyStore keyStore = KeyStore.getInstance(KeyStore.getDefaultType()); - keyStore.load(inputStream, PASSWORD.toCharArray()); - inputStream.close(); - - String algorithm = KeyManagerFactory.getDefaultAlgorithm(); - KeyManagerFactory keyManagerFactory = KeyManagerFactory.getInstance(algorithm); - keyManagerFactory.init(keyStore, PASSWORD.toCharArray()); - - return keyManagerFactory.getKeyManagers(); - } - - - public ServerThread(TestWebServer server, boolean ssl) throws Exception { - super("ServerThread"); - mServer = server; - mIsSsl = ssl; - int retry = 3; - while (true) { - try { - if (mIsSsl) { - mSslContext = SSLContext.getInstance("TLS"); - mSslContext.init(getKeyManagers(), null, null); - mSocket = mSslContext.getServerSocketFactory().createServerSocket( - SSL_SERVER_PORT); - } else { - mSocket = new ServerSocket(SERVER_PORT); - } - return; - } catch (IOException e) { - Log.w(TAG, e); - if (--retry == 0) { - throw e; - } - // sleep in case server socket is still being closed - Thread.sleep(1000); - } - } - } - - @Override - public void run() { - HttpParams params = new BasicHttpParams(); - params.setParameter(CoreProtocolPNames.PROTOCOL_VERSION, HttpVersion.HTTP_1_0); - while (!mIsCancelled) { - try { - Socket socket = mSocket.accept(); - DefaultHttpServerConnection conn = new DefaultHttpServerConnection(); - conn.bind(socket, params); - - // Determine whether we need to shutdown early before - // parsing the response since conn.close() will crash - // for SSL requests due to UnsupportedOperationException. - HttpRequest request = conn.receiveRequestHeader(); - if (isShutdownRequest(request)) { - mIsCancelled = true; - } - - HttpResponse response = mServer.getResponse(request); - conn.sendResponseHeader(response); - conn.sendResponseEntity(response); - conn.close(); - - } catch (IOException e) { - // normal during shutdown, ignore - Log.w(TAG, e); - } catch (HttpException e) { - Log.w(TAG, e); - } catch (InterruptedException e) { - Log.w(TAG, e); - } catch (UnsupportedOperationException e) { - // DefaultHttpServerConnection's close() throws an - // UnsupportedOperationException. - Log.w(TAG, e); - } - } - try { - mSocket.close(); - } catch (IOException ignored) { - // safe to ignore - } - } - - private boolean isShutdownRequest(HttpRequest request) { - RequestLine requestLine = request.getRequestLine(); - String uriString = requestLine.getUri(); - URI uri = URI.create(uriString); - String path = uri.getPath(); - return path.equals(SHUTDOWN_PREFIX); - } - } -} diff --git a/chromium/net/test/cert_test_util.cc b/chromium/net/test/cert_test_util.cc index 085a4594c76..5ec07749d50 100644 --- a/chromium/net/test/cert_test_util.cc +++ b/chromium/net/test/cert_test_util.cc @@ -19,7 +19,7 @@ CertificateList CreateCertificateListFromFile( int format) { base::FilePath cert_path = certs_dir.AppendASCII(cert_file); std::string cert_data; - if (!file_util::ReadFileToString(cert_path, &cert_data)) + if (!base::ReadFileToString(cert_path, &cert_data)) return CertificateList(); return X509Certificate::CreateCertificateListFromBytes(cert_data.data(), cert_data.size(), @@ -31,7 +31,7 @@ scoped_refptr<X509Certificate> ImportCertFromFile( const std::string& cert_file) { base::FilePath cert_path = certs_dir.AppendASCII(cert_file); std::string cert_data; - if (!file_util::ReadFileToString(cert_path, &cert_data)) + if (!base::ReadFileToString(cert_path, &cert_data)) return NULL; CertificateList certs_in_file = diff --git a/chromium/net/test/embedded_test_server/embedded_test_server.cc b/chromium/net/test/embedded_test_server/embedded_test_server.cc index 9175d6ca894..07bad2813fa 100644 --- a/chromium/net/test/embedded_test_server/embedded_test_server.cc +++ b/chromium/net/test/embedded_test_server/embedded_test_server.cc @@ -5,8 +5,9 @@ #include "net/test/embedded_test_server/embedded_test_server.h" #include "base/bind.h" -#include "base/files/file_path.h" #include "base/file_util.h" +#include "base/files/file_path.h" +#include "base/message_loop/message_loop.h" #include "base/path_service.h" #include "base/run_loop.h" #include "base/stl_util.h" @@ -59,7 +60,7 @@ scoped_ptr<HttpResponse> HandleFileRequest( base::FilePath file_path(server_root.AppendASCII(request_path)); std::string file_contents; - if (!file_util::ReadFileToString(file_path, &file_contents)) + if (!base::ReadFileToString(file_path, &file_contents)) return scoped_ptr<HttpResponse>(); base::FilePath headers_path( @@ -67,7 +68,7 @@ scoped_ptr<HttpResponse> HandleFileRequest( if (base::PathExists(headers_path)) { std::string headers_contents; - if (!file_util::ReadFileToString(headers_path, &headers_contents)) + if (!base::ReadFileToString(headers_path, &headers_contents)) return scoped_ptr<HttpResponse>(); scoped_ptr<CustomHttpResponse> http_response( @@ -118,15 +119,10 @@ EmbeddedTestServer::~EmbeddedTestServer() { bool EmbeddedTestServer::InitializeAndWaitUntilReady() { DCHECK(thread_checker_.CalledOnValidThread()); - base::RunLoop run_loop; - if (!io_thread_->PostTaskAndReply( - FROM_HERE, - base::Bind(&EmbeddedTestServer::InitializeOnIOThread, - base::Unretained(this)), - run_loop.QuitClosure())) { + if (!PostTaskToIOThreadAndWait(base::Bind( + &EmbeddedTestServer::InitializeOnIOThread, base::Unretained(this)))) { return false; } - run_loop.Run(); return Started() && base_url_.is_valid(); } @@ -134,17 +130,8 @@ bool EmbeddedTestServer::InitializeAndWaitUntilReady() { bool EmbeddedTestServer::ShutdownAndWaitUntilComplete() { DCHECK(thread_checker_.CalledOnValidThread()); - base::RunLoop run_loop; - if (!io_thread_->PostTaskAndReply( - FROM_HERE, - base::Bind(&EmbeddedTestServer::ShutdownOnIOThread, - base::Unretained(this)), - run_loop.QuitClosure())) { - return false; - } - run_loop.Run(); - - return true; + return PostTaskToIOThreadAndWait(base::Bind( + &EmbeddedTestServer::ShutdownOnIOThread, base::Unretained(this))); } void EmbeddedTestServer::InitializeOnIOThread() { @@ -153,10 +140,10 @@ void EmbeddedTestServer::InitializeOnIOThread() { SocketDescriptor socket_descriptor = TCPListenSocket::CreateAndBindAnyPort("127.0.0.1", &port_); - if (socket_descriptor == TCPListenSocket::kInvalidSocket) + if (socket_descriptor == kInvalidSocket) return; - listen_socket_ = new HttpListenSocket(socket_descriptor, this); + listen_socket_.reset(new HttpListenSocket(socket_descriptor, this)); listen_socket_->Listen(); IPEndPoint address; @@ -171,7 +158,7 @@ void EmbeddedTestServer::InitializeOnIOThread() { void EmbeddedTestServer::ShutdownOnIOThread() { DCHECK(io_thread_->BelongsToCurrentThread()); - listen_socket_ = NULL; // Release the listen socket. + listen_socket_.reset(); STLDeleteContainerPairSecondPointers(connections_.begin(), connections_.end()); connections_.clear(); @@ -224,15 +211,17 @@ void EmbeddedTestServer::RegisterRequestHandler( request_handlers_.push_back(callback); } -void EmbeddedTestServer::DidAccept(StreamListenSocket* server, - StreamListenSocket* connection) { +void EmbeddedTestServer::DidAccept( + StreamListenSocket* server, + scoped_ptr<StreamListenSocket> connection) { DCHECK(io_thread_->BelongsToCurrentThread()); HttpConnection* http_connection = new HttpConnection( - connection, + connection.Pass(), base::Bind(&EmbeddedTestServer::HandleRequest, weak_factory_.GetWeakPtr())); - connections_[connection] = http_connection; + // TODO(szym): Make HttpConnection the StreamListenSocket delegate. + connections_[http_connection->socket_.get()] = http_connection; } void EmbeddedTestServer::DidRead(StreamListenSocket* connection, @@ -272,5 +261,27 @@ HttpConnection* EmbeddedTestServer::FindConnection( return it->second; } +bool EmbeddedTestServer::PostTaskToIOThreadAndWait( + const base::Closure& closure) { + // Note that PostTaskAndReply below requires base::MessageLoopProxy::current() + // to return a loop for posting the reply task. However, in order to make + // EmbeddedTestServer universally usable, it needs to cope with the situation + // where it's running on a thread on which a message loop is not (yet) + // available or as has been destroyed already. + // + // To handle this situation, create temporary message loop to support the + // PostTaskAndReply operation if the current thread as no message loop. + scoped_ptr<base::MessageLoop> temporary_loop; + if (!base::MessageLoop::current()) + temporary_loop.reset(new base::MessageLoop()); + + base::RunLoop run_loop; + if (!io_thread_->PostTaskAndReply(FROM_HERE, closure, run_loop.QuitClosure())) + return false; + run_loop.Run(); + + return true; +} + } // namespace test_server } // namespace net diff --git a/chromium/net/test/embedded_test_server/embedded_test_server.h b/chromium/net/test/embedded_test_server/embedded_test_server.h index 879c4a947f9..f8c1bb5b3cb 100644 --- a/chromium/net/test/embedded_test_server/embedded_test_server.h +++ b/chromium/net/test/embedded_test_server/embedded_test_server.h @@ -10,6 +10,7 @@ #include <vector> #include "base/basictypes.h" +#include "base/callback.h" #include "base/compiler_specific.h" #include "base/memory/ref_counted.h" #include "base/memory/weak_ptr.h" @@ -33,10 +34,10 @@ class HttpListenSocket : public TCPListenSocket { public: HttpListenSocket(const SocketDescriptor socket_descriptor, StreamListenSocket::Delegate* delegate); + virtual ~HttpListenSocket(); virtual void Listen(); private: - virtual ~HttpListenSocket(); base::ThreadChecker thread_checker_; }; @@ -53,7 +54,7 @@ class HttpListenSocket : public TCPListenSocket { // // void SetUp() { // base::Thread::Options thread_options; -// thread_options.message_loop_type = MessageLoop::TYPE_IO; +// thread_options.message_loop_type = base::MessageLoop::TYPE_IO; // ASSERT_TRUE(io_thread_.StartWithOptions(thread_options)); // // test_server_.reset( @@ -137,7 +138,7 @@ class EmbeddedTestServer : public StreamListenSocket::Delegate { // StreamListenSocket::Delegate overrides: virtual void DidAccept(StreamListenSocket* server, - StreamListenSocket* connection) OVERRIDE; + scoped_ptr<StreamListenSocket> connection) OVERRIDE; virtual void DidRead(StreamListenSocket* connection, const char* data, int length) OVERRIDE; @@ -145,9 +146,13 @@ class EmbeddedTestServer : public StreamListenSocket::Delegate { HttpConnection* FindConnection(StreamListenSocket* socket); + // Posts a task to the |io_thread_| and waits for a reply. + bool PostTaskToIOThreadAndWait( + const base::Closure& closure) WARN_UNUSED_RESULT; + scoped_refptr<base::SingleThreadTaskRunner> io_thread_; - scoped_refptr<HttpListenSocket> listen_socket_; + scoped_ptr<HttpListenSocket> listen_socket_; int port_; GURL base_url_; diff --git a/chromium/net/test/embedded_test_server/embedded_test_server_unittest.cc b/chromium/net/test/embedded_test_server/embedded_test_server_unittest.cc index 35d0fd414e1..2c005186607 100644 --- a/chromium/net/test/embedded_test_server/embedded_test_server_unittest.cc +++ b/chromium/net/test/embedded_test_server/embedded_test_server_unittest.cc @@ -40,8 +40,8 @@ std::string GetContentTypeFromFetcher(const URLFetcher& fetcher) { } // namespace -class EmbeddedTestServerTest : public testing::Test, - public URLFetcherDelegate { +class EmbeddedTestServerTest: public testing::Test, + public URLFetcherDelegate { public: EmbeddedTestServerTest() : num_responses_received_(0), @@ -240,5 +240,90 @@ TEST_F(EmbeddedTestServerTest, ConcurrentFetches) { EXPECT_EQ("text/plain", GetContentTypeFromFetcher(*fetcher3)); } +// Below test exercises EmbeddedTestServer's ability to cope with the situation +// where there is no MessageLoop available on the thread at EmbeddedTestServer +// initialization and/or destruction. + +typedef std::tr1::tuple<bool, bool> ThreadingTestParams; + +class EmbeddedTestServerThreadingTest + : public testing::TestWithParam<ThreadingTestParams> {}; + +class EmbeddedTestServerThreadingTestDelegate + : public base::PlatformThread::Delegate, + public URLFetcherDelegate { + public: + EmbeddedTestServerThreadingTestDelegate( + bool message_loop_present_on_initialize, + bool message_loop_present_on_shutdown) + : message_loop_present_on_initialize_(message_loop_present_on_initialize), + message_loop_present_on_shutdown_(message_loop_present_on_shutdown) {} + + // base::PlatformThread::Delegate: + virtual void ThreadMain() OVERRIDE { + scoped_refptr<base::SingleThreadTaskRunner> io_thread_runner; + base::Thread io_thread("io_thread"); + base::Thread::Options thread_options; + thread_options.message_loop_type = base::MessageLoop::TYPE_IO; + ASSERT_TRUE(io_thread.StartWithOptions(thread_options)); + io_thread_runner = io_thread.message_loop_proxy(); + + scoped_ptr<base::MessageLoop> loop; + if (message_loop_present_on_initialize_) + loop.reset(new base::MessageLoop(base::MessageLoop::TYPE_IO)); + + // Create the test server instance. + EmbeddedTestServer server(io_thread_runner); + base::FilePath src_dir; + ASSERT_TRUE(PathService::Get(base::DIR_SOURCE_ROOT, &src_dir)); + ASSERT_TRUE(server.InitializeAndWaitUntilReady()); + + // Make a request and wait for the reply. + if (!loop) + loop.reset(new base::MessageLoop(base::MessageLoop::TYPE_IO)); + + scoped_ptr<URLFetcher> fetcher(URLFetcher::Create( + server.GetURL("/test?q=foo"), URLFetcher::GET, this)); + fetcher->SetRequestContext( + new TestURLRequestContextGetter(loop->message_loop_proxy())); + fetcher->Start(); + loop->Run(); + fetcher.reset(); + + // Shut down. + if (message_loop_present_on_shutdown_) + loop.reset(); + + ASSERT_TRUE(server.ShutdownAndWaitUntilComplete()); + } + + // URLFetcherDelegate override. + virtual void OnURLFetchComplete(const URLFetcher* source) OVERRIDE { + base::MessageLoop::current()->Quit(); + } + + private: + bool message_loop_present_on_initialize_; + bool message_loop_present_on_shutdown_; + + DISALLOW_COPY_AND_ASSIGN(EmbeddedTestServerThreadingTestDelegate); +}; + +TEST_P(EmbeddedTestServerThreadingTest, RunTest) { + // The actual test runs on a separate thread so it can screw with the presence + // of a MessageLoop - the test suite already sets up a MessageLoop for the + // main test thread. + base::PlatformThreadHandle thread_handle; + EmbeddedTestServerThreadingTestDelegate delegate( + std::tr1::get<0>(GetParam()), + std::tr1::get<1>(GetParam())); + ASSERT_TRUE(base::PlatformThread::Create(0, &delegate, &thread_handle)); + base::PlatformThread::Join(thread_handle); +} + +INSTANTIATE_TEST_CASE_P(EmbeddedTestServerThreadingTestInstantiation, + EmbeddedTestServerThreadingTest, + testing::Combine(testing::Bool(), testing::Bool())); + } // namespace test_server } // namespace net diff --git a/chromium/net/test/embedded_test_server/http_connection.cc b/chromium/net/test/embedded_test_server/http_connection.cc index 8b5317e320b..b7eab2ebd3b 100644 --- a/chromium/net/test/embedded_test_server/http_connection.cc +++ b/chromium/net/test/embedded_test_server/http_connection.cc @@ -10,9 +10,9 @@ namespace net { namespace test_server { -HttpConnection::HttpConnection(StreamListenSocket* socket, +HttpConnection::HttpConnection(scoped_ptr<StreamListenSocket> socket, const HandleRequestCallback& callback) - : socket_(socket), + : socket_(socket.Pass()), callback_(callback) { } diff --git a/chromium/net/test/embedded_test_server/http_connection.h b/chromium/net/test/embedded_test_server/http_connection.h index da9353404c6..870d12269dd 100644 --- a/chromium/net/test/embedded_test_server/http_connection.h +++ b/chromium/net/test/embedded_test_server/http_connection.h @@ -7,7 +7,6 @@ #include "base/basictypes.h" #include "base/callback.h" -#include "base/memory/ref_counted.h" #include "base/strings/string_piece.h" #include "net/test/embedded_test_server/http_request.h" @@ -30,7 +29,7 @@ typedef base::Callback<void(HttpConnection* connection, // If a valid request is parsed, then |callback_| is invoked. class HttpConnection { public: - HttpConnection(StreamListenSocket* socket, + HttpConnection(scoped_ptr<StreamListenSocket> socket, const HandleRequestCallback& callback); ~HttpConnection(); @@ -45,7 +44,7 @@ class HttpConnection { // called. void ReceiveData(const base::StringPiece& data); - scoped_refptr<StreamListenSocket> socket_; + scoped_ptr<StreamListenSocket> socket_; const HandleRequestCallback callback_; HttpRequestParser request_parser_; diff --git a/chromium/net/test/python_utils.cc b/chromium/net/test/python_utils.cc index 30c5f679a8a..a9ac98b5317 100644 --- a/chromium/net/test/python_utils.cc +++ b/chromium/net/test/python_utils.cc @@ -106,18 +106,8 @@ bool GetPyProtoPath(base::FilePath* dir) { bool GetPythonCommand(CommandLine* python_cmd) { DCHECK(python_cmd); - base::FilePath dir; -#if defined(OS_WIN) - if (!PathService::Get(base::DIR_SOURCE_ROOT, &dir)) - return false; - dir = dir.Append(FILE_PATH_LITERAL("third_party")) - .Append(FILE_PATH_LITERAL("python_26")) - .Append(FILE_PATH_LITERAL("python.exe")); -#elif defined(OS_POSIX) - dir = base::FilePath("python"); -#endif - python_cmd->SetProgram(dir); + python_cmd->SetProgram(base::FilePath(FILE_PATH_LITERAL("python"))); // Launch python in unbuffered mode, so that python output doesn't mix with // gtest output in buildbot log files. See http://crbug.com/147368. diff --git a/chromium/net/test/run_all_unittests.cc b/chromium/net/test/run_all_unittests.cc index d8392ff0e2a..07aaeb63b5d 100644 --- a/chromium/net/test/run_all_unittests.cc +++ b/chromium/net/test/run_all_unittests.cc @@ -3,6 +3,7 @@ // found in the LICENSE file. #include "base/metrics/statistics_recorder.h" +#include "base/test/unit_test_launcher.h" #include "build/build_config.h" #include "crypto/nss_util.h" #include "net/socket/client_socket_pool_base.h" @@ -49,5 +50,7 @@ int main(int argc, char** argv) { net::ProxyResolverV8::RememberDefaultIsolate(); #endif - return test_suite.Run(); + return base::LaunchUnitTests( + argc, argv, base::Bind(&NetTestSuite::Run, + base::Unretained(&test_suite))); } diff --git a/chromium/net/test/spawned_test_server/base_test_server.cc b/chromium/net/test/spawned_test_server/base_test_server.cc index c13745bf51e..b8697d49f51 100644 --- a/chromium/net/test/spawned_test_server/base_test_server.cc +++ b/chromium/net/test/spawned_test_server/base_test_server.cc @@ -171,7 +171,11 @@ bool BaseTestServer::GetAddressList(AddressList* address_list) const { scoped_ptr<HostResolver> resolver(HostResolver::CreateDefaultResolver(NULL)); HostResolver::RequestInfo info(host_port_pair_); TestCompletionCallback callback; - int rv = resolver->Resolve(info, address_list, callback.callback(), NULL, + int rv = resolver->Resolve(info, + DEFAULT_PRIORITY, + address_list, + callback.callback(), + NULL, BoundNetLog()); if (rv == ERR_IO_PENDING) rv = callback.WaitForResult(); diff --git a/chromium/net/test/spawned_test_server/local_test_server_win.cc b/chromium/net/test/spawned_test_server/local_test_server_win.cc index fd26483ed47..08d68f80bf2 100644 --- a/chromium/net/test/spawned_test_server/local_test_server_win.cc +++ b/chromium/net/test/spawned_test_server/local_test_server_win.cc @@ -10,6 +10,7 @@ #include "base/base_paths.h" #include "base/bind.h" #include "base/command_line.h" +#include "base/environment.h" #include "base/files/file_path.h" #include "base/message_loop/message_loop.h" #include "base/path_service.h" @@ -81,6 +82,60 @@ bool ReadData(HANDLE read_fd, HANDLE write_fd, return true; } +// Class that sets up a temporary path that includes the supplied path +// at the end. +// +// TODO(bratell): By making this more generic we can possibly reuse +// it at other places such as +// chrome/common/multi_process_lock_unittest.cc. +class ScopedPath { + public: + // Constructor which sets up the environment to include the path to + // |path_to_add|. + explicit ScopedPath(const base::FilePath& path_to_add); + + // Destructor that restores the path that were active when the + // object was constructed. + ~ScopedPath(); + + private: + // The PATH environment variable before it was changed or an empty + // string if there was no PATH environment variable. + std::string old_path_; + + // The helper object that allows us to read and set environment + // variables more easily. + scoped_ptr<base::Environment> environment_; + + // A flag saying if we have actually modified the environment. + bool path_modified_; + + DISALLOW_COPY_AND_ASSIGN(ScopedPath); +}; + +ScopedPath::ScopedPath(const base::FilePath& path_to_add) + : environment_(base::Environment::Create()), + path_modified_(false) { + environment_->GetVar("PATH", &old_path_); + + std::string new_value = old_path_; + if (!new_value.empty()) + new_value += ";"; + + new_value += WideToUTF8(path_to_add.value()); + + path_modified_ = environment_->SetVar("PATH", new_value); +} + +ScopedPath::~ScopedPath() { + if (!path_modified_) + return; + if (old_path_.empty()) + environment_->UnSetVar("PATH"); + else + environment_->SetVar("PATH", old_path_); +} + } // namespace namespace net { @@ -129,11 +184,21 @@ bool LocalTestServer::LaunchPython(const base::FilePath& testserver_path) { return false; } - if (!base::SetJobObjectAsKillOnJobClose(job_handle_.Get())) { - LOG(ERROR) << "Could not SetInformationJobObject."; + if (!base::SetJobObjectLimitFlags(job_handle_.Get(), + JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE)) { + LOG(ERROR) << "Could not SetJobObjectLimitFlags."; return false; } + // Add our internal python to the path so it can be used if there is + // no system python. + base::FilePath python_dir; + if (!PathService::Get(base::DIR_SOURCE_ROOT, &python_dir)) { + LOG(ERROR) << "Could not locate source root directory."; + return false; + } + python_dir = python_dir.AppendASCII("third_party").AppendASCII("python_26"); + ScopedPath python_path(python_dir); base::LaunchOptions launch_options; launch_options.inherit_handles = true; launch_options.job_handle = job_handle_.Get(); diff --git a/chromium/net/test/spawned_test_server/remote_test_server.cc b/chromium/net/test/spawned_test_server/remote_test_server.cc index a3b4ef3fdf5..594c5d488b4 100644 --- a/chromium/net/test/spawned_test_server/remote_test_server.cc +++ b/chromium/net/test/spawned_test_server/remote_test_server.cc @@ -169,8 +169,7 @@ bool RemoteTestServer::Init(const base::FilePath& document_root) { // Parse file to extract the ports information. std::string port_info; - if (!file_util::ReadFileToString(GetTestServerPortInfoFile(), - &port_info) || + if (!base::ReadFileToString(GetTestServerPortInfoFile(), &port_info) || port_info.empty()) { return false; } diff --git a/chromium/net/third_party/nss/README.chromium b/chromium/net/third_party/nss/README.chromium index 69b7eb77316..4a0f2d3ec19 100644 --- a/chromium/net/third_party/nss/README.chromium +++ b/chromium/net/third_party/nss/README.chromium @@ -31,6 +31,7 @@ Patches: * Add the SSL_PeerCertificateChain function patches/peercertchain.patch + patches/peercertchain2.patch https://bugzilla.mozilla.org/show_bug.cgi?id=731485 * Add support for client auth with native crypto APIs on Mac and Windows @@ -74,6 +75,8 @@ Patches: NSS that doesn't contain the CBC constant-time changes. patches/cbc.patch https://code.google.com/p/chromium/issues/detail?id=172658#c12 + TODO(wtc): remove this patch now that NSS 3.14.3 is the minimum + compile-time and run-time version. * Change ssl3_SuiteBOnly to always return PR_TRUE. The softoken in NSS versions older than 3.15 report an EC key size range of 112 bits to 571 @@ -97,11 +100,45 @@ Patches: https://bugzilla.mozilla.org/show_bug.cgi?id=903565 patches/sslsock_903565.patch + * Implement the AES GCM cipher suites. + https://bugzilla.mozilla.org/show_bug.cgi?id=880543 + patches/aesgcm.patch + + * Add Chromium-specific code to detect AES GCM support in the system NSS + libraries at run time. + patches/aesgcmchromium.patch + * Prefer to generate SHA-1 signatures for TLS 1.2 client authentication if the client private key is in a CAPI service provider on Windows or if the client private key is a 1024-bit RSA or DSA key. patches/tls12backuphash.patch + * Support ChaCha20+Poly1305 ciphersuites + http://tools.ietf.org/html/draft-agl-tls-chacha20poly1305-01 + patches/chacha20poly1305.patch + + * Fix session cache lock creation race. + patches/cachelocks.patch + https://bugzilla.mozilla.org/show_bug.cgi?id=764646 + + * Don't advertise TLS 1.2-only cipher suites in a TLS 1.1 ClientHello. + https://bugzilla.mozilla.org/show_bug.cgi?id=919677 + patches/ciphersuiteversion.patch + + * Don't use record versions greater than 0x0301 in resumption ClientHello + records either. + https://bugzilla.mozilla.org/show_bug.cgi?id=923696 + https://code.google.com/p/chromium/issues/detail?id=303398 + patches/resumeclienthelloversion.patch + + * Make SSL False Start work with asynchronous certificate validation. + https://bugzilla.mozilla.org/show_bug.cgi?id=713933 + patches/canfalsestart.patch + + * Have the Null Cipher limit output to the maximum allowed + https://bugzilla.mozilla.org/show_bug.cgi?id=934016 + patches/nullcipher_934016.patch + Apply the patches to NSS by running the patches/applypatches.sh script. Read the comments at the top of patches/applypatches.sh for instructions. diff --git a/chromium/net/third_party/nss/patches/aesgcm.patch b/chromium/net/third_party/nss/patches/aesgcm.patch new file mode 100644 index 00000000000..03fdf8e255a --- /dev/null +++ b/chromium/net/third_party/nss/patches/aesgcm.patch @@ -0,0 +1,1363 @@ +Index: net/third_party/nss/ssl/sslinfo.c +=================================================================== +--- net/third_party/nss/ssl/sslinfo.c (revision 217715) ++++ net/third_party/nss/ssl/sslinfo.c (working copy) +@@ -109,7 +109,7 @@ + #define K_ECDHE "ECDHE", kt_ecdh + + #define C_SEED "SEED", calg_seed +-#define C_CAMELLIA "CAMELLIA", calg_camellia ++#define C_CAMELLIA "CAMELLIA", calg_camellia + #define C_AES "AES", calg_aes + #define C_RC4 "RC4", calg_rc4 + #define C_RC2 "RC2", calg_rc2 +@@ -117,6 +117,7 @@ + #define C_3DES "3DES", calg_3des + #define C_NULL "NULL", calg_null + #define C_SJ "SKIPJACK", calg_sj ++#define C_AESGCM "AES-GCM", calg_aes_gcm + + #define B_256 256, 256, 256 + #define B_128 128, 128, 128 +@@ -127,12 +128,16 @@ + #define B_40 128, 40, 40 + #define B_0 0, 0, 0 + ++#define M_AEAD_128 "AEAD", ssl_mac_aead, 128 + #define M_SHA256 "SHA256", ssl_hmac_sha256, 256 + #define M_SHA "SHA1", ssl_mac_sha, 160 + #define M_MD5 "MD5", ssl_mac_md5, 128 ++#define M_NULL "NULL", ssl_mac_null, 0 + + static const SSLCipherSuiteInfo suiteInfo[] = { + /* <------ Cipher suite --------------------> <auth> <KEA> <bulk cipher> <MAC> <FIPS> */ ++{0,CS(TLS_RSA_WITH_AES_128_GCM_SHA256), S_RSA, K_RSA, C_AESGCM, B_128, M_AEAD_128, 1, 0, 0, }, ++ + {0,CS(TLS_DHE_RSA_WITH_CAMELLIA_256_CBC_SHA), S_RSA, K_DHE, C_CAMELLIA, B_256, M_SHA, 0, 0, 0, }, + {0,CS(TLS_DHE_DSS_WITH_CAMELLIA_256_CBC_SHA), S_DSA, K_DHE, C_CAMELLIA, B_256, M_SHA, 0, 0, 0, }, + {0,CS(TLS_DHE_RSA_WITH_AES_256_CBC_SHA256), S_RSA, K_DHE, C_AES, B_256, M_SHA256, 1, 0, 0, }, +@@ -146,6 +151,7 @@ + {0,CS(TLS_DHE_DSS_WITH_CAMELLIA_128_CBC_SHA), S_DSA, K_DHE, C_CAMELLIA, B_128, M_SHA, 0, 0, 0, }, + {0,CS(TLS_DHE_DSS_WITH_RC4_128_SHA), S_DSA, K_DHE, C_RC4, B_128, M_SHA, 0, 0, 0, }, + {0,CS(TLS_DHE_RSA_WITH_AES_128_CBC_SHA256), S_RSA, K_DHE, C_AES, B_128, M_SHA256, 1, 0, 0, }, ++{0,CS(TLS_DHE_RSA_WITH_AES_128_GCM_SHA256), S_RSA, K_DHE, C_AESGCM, B_128, M_AEAD_128, 1, 0, 0, }, + {0,CS(TLS_DHE_RSA_WITH_AES_128_CBC_SHA), S_RSA, K_DHE, C_AES, B_128, M_SHA, 1, 0, 0, }, + {0,CS(TLS_DHE_DSS_WITH_AES_128_CBC_SHA), S_DSA, K_DHE, C_AES, B_128, M_SHA, 1, 0, 0, }, + {0,CS(TLS_RSA_WITH_SEED_CBC_SHA), S_RSA, K_RSA, C_SEED,B_128, M_SHA, 1, 0, 0, }, +@@ -175,6 +181,9 @@ + + #ifdef NSS_ENABLE_ECC + /* ECC cipher suites */ ++{0,CS(TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256), S_RSA, K_ECDHE, C_AESGCM, B_128, M_AEAD_128, 1, 0, 0, }, ++{0,CS(TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256), S_ECDSA, K_ECDHE, C_AESGCM, B_128, M_AEAD_128, 1, 0, 0, }, ++ + {0,CS(TLS_ECDH_ECDSA_WITH_NULL_SHA), S_ECDSA, K_ECDH, C_NULL, B_0, M_SHA, 0, 0, 0, }, + {0,CS(TLS_ECDH_ECDSA_WITH_RC4_128_SHA), S_ECDSA, K_ECDH, C_RC4, B_128, M_SHA, 0, 0, 0, }, + {0,CS(TLS_ECDH_ECDSA_WITH_3DES_EDE_CBC_SHA), S_ECDSA, K_ECDH, C_3DES, B_3DES, M_SHA, 1, 0, 0, }, +Index: net/third_party/nss/ssl/sslimpl.h +=================================================================== +--- net/third_party/nss/ssl/sslimpl.h (revision 217715) ++++ net/third_party/nss/ssl/sslimpl.h (working copy) +@@ -64,6 +64,7 @@ + #define calg_aes ssl_calg_aes + #define calg_camellia ssl_calg_camellia + #define calg_seed ssl_calg_seed ++#define calg_aes_gcm ssl_calg_aes_gcm + + #define mac_null ssl_mac_null + #define mac_md5 ssl_mac_md5 +@@ -71,6 +72,7 @@ + #define hmac_md5 ssl_hmac_md5 + #define hmac_sha ssl_hmac_sha + #define hmac_sha256 ssl_hmac_sha256 ++#define mac_aead ssl_mac_aead + + #define SET_ERROR_CODE /* reminder */ + #define SEND_ALERT /* reminder */ +@@ -290,9 +292,9 @@ + } ssl3CipherSuiteCfg; + + #ifdef NSS_ENABLE_ECC +-#define ssl_V3_SUITES_IMPLEMENTED 57 ++#define ssl_V3_SUITES_IMPLEMENTED 61 + #else +-#define ssl_V3_SUITES_IMPLEMENTED 35 ++#define ssl_V3_SUITES_IMPLEMENTED 37 + #endif /* NSS_ENABLE_ECC */ + + #define MAX_DTLS_SRTP_CIPHER_SUITES 4 +@@ -440,20 +442,6 @@ + #define GS_DATA 3 + #define GS_PAD 4 + +-typedef SECStatus (*SSLCipher)(void * context, +- unsigned char * out, +- int * outlen, +- int maxout, +- const unsigned char *in, +- int inlen); +-typedef SECStatus (*SSLCompressor)(void * context, +- unsigned char * out, +- int * outlen, +- int maxout, +- const unsigned char *in, +- int inlen); +-typedef SECStatus (*SSLDestroy)(void *context, PRBool freeit); +- + #if defined(NSS_PLATFORM_CLIENT_AUTH) && defined(XP_WIN32) + typedef PCERT_KEY_CONTEXT PlatformKey; + #elif defined(NSS_PLATFORM_CLIENT_AUTH) && defined(XP_MACOSX) +@@ -485,11 +473,12 @@ + cipher_camellia_128, + cipher_camellia_256, + cipher_seed, ++ cipher_aes_128_gcm, + cipher_missing /* reserved for no such supported cipher */ + /* This enum must match ssl3_cipherName[] in ssl3con.c. */ + } SSL3BulkCipher; + +-typedef enum { type_stream, type_block } CipherType; ++typedef enum { type_stream, type_block, type_aead } CipherType; + + #define MAX_IV_LENGTH 24 + +@@ -531,6 +520,30 @@ + PRUint64 cipher_context[MAX_CIPHER_CONTEXT_LLONGS]; + } ssl3KeyMaterial; + ++typedef SECStatus (*SSLCipher)(void * context, ++ unsigned char * out, ++ int * outlen, ++ int maxout, ++ const unsigned char *in, ++ int inlen); ++typedef SECStatus (*SSLAEADCipher)( ++ ssl3KeyMaterial * keys, ++ PRBool doDecrypt, ++ unsigned char * out, ++ int * outlen, ++ int maxout, ++ const unsigned char *in, ++ int inlen, ++ const unsigned char *additionalData, ++ int additionalDataLen); ++typedef SECStatus (*SSLCompressor)(void * context, ++ unsigned char * out, ++ int * outlen, ++ int maxout, ++ const unsigned char *in, ++ int inlen); ++typedef SECStatus (*SSLDestroy)(void *context, PRBool freeit); ++ + /* The DTLS anti-replay window. Defined here because we need it in + * the cipher spec. Note that this is a ring buffer but left and + * right represent the true window, with modular arithmetic used to +@@ -557,6 +570,7 @@ + int mac_size; + SSLCipher encode; + SSLCipher decode; ++ SSLAEADCipher aead; + SSLDestroy destroy; + void * encodeContext; + void * decodeContext; +@@ -706,8 +720,6 @@ + PRBool tls_keygen; + } ssl3KEADef; + +-typedef enum { kg_null, kg_strong, kg_export } SSL3KeyGenMode; +- + /* + ** There are tables of these, all const. + */ +@@ -719,7 +731,8 @@ + CipherType type; + int iv_size; + int block_size; +- SSL3KeyGenMode keygen_mode; ++ int tag_size; /* authentication tag size for AEAD ciphers. */ ++ int explicit_nonce_size; /* for AEAD ciphers. */ + }; + + /* +Index: net/third_party/nss/ssl/ssl3ecc.c +=================================================================== +--- net/third_party/nss/ssl/ssl3ecc.c (revision 217715) ++++ net/third_party/nss/ssl/ssl3ecc.c (working copy) +@@ -911,6 +911,7 @@ + TLS_ECDHE_ECDSA_WITH_3DES_EDE_CBC_SHA, + TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, + TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, ++ TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, + TLS_ECDHE_ECDSA_WITH_NULL_SHA, + TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, +@@ -921,6 +922,7 @@ + TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, + TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, + TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, ++ TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, + TLS_ECDHE_RSA_WITH_NULL_SHA, + TLS_ECDHE_RSA_WITH_RC4_128_SHA, +@@ -932,12 +934,14 @@ + TLS_ECDHE_ECDSA_WITH_3DES_EDE_CBC_SHA, + TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, + TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, ++ TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, + TLS_ECDHE_ECDSA_WITH_NULL_SHA, + TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, + TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, + TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, + TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, ++ TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, + TLS_ECDHE_RSA_WITH_NULL_SHA, + TLS_ECDHE_RSA_WITH_RC4_128_SHA, +Index: net/third_party/nss/ssl/sslsock.c +=================================================================== +--- net/third_party/nss/ssl/sslsock.c (revision 217715) ++++ net/third_party/nss/ssl/sslsock.c (working copy) +@@ -67,8 +67,10 @@ + { TLS_DHE_DSS_WITH_AES_128_CBC_SHA, SSL_NOT_ALLOWED, SSL_NOT_ALLOWED }, + { TLS_DHE_RSA_WITH_AES_128_CBC_SHA, SSL_NOT_ALLOWED, SSL_NOT_ALLOWED }, + { TLS_DHE_RSA_WITH_AES_128_CBC_SHA256, SSL_NOT_ALLOWED, SSL_NOT_ALLOWED }, ++ { TLS_DHE_RSA_WITH_AES_128_GCM_SHA256, SSL_NOT_ALLOWED, SSL_NOT_ALLOWED }, + { TLS_RSA_WITH_AES_128_CBC_SHA, SSL_NOT_ALLOWED, SSL_NOT_ALLOWED }, + { TLS_RSA_WITH_AES_128_CBC_SHA256, SSL_NOT_ALLOWED, SSL_NOT_ALLOWED }, ++ { TLS_RSA_WITH_AES_128_GCM_SHA256, SSL_NOT_ALLOWED, SSL_NOT_ALLOWED }, + { TLS_DHE_DSS_WITH_AES_256_CBC_SHA, SSL_NOT_ALLOWED, SSL_NOT_ALLOWED }, + { TLS_DHE_RSA_WITH_AES_256_CBC_SHA, SSL_NOT_ALLOWED, SSL_NOT_ALLOWED }, + { TLS_DHE_RSA_WITH_AES_256_CBC_SHA256, SSL_NOT_ALLOWED, SSL_NOT_ALLOWED }, +@@ -94,6 +96,7 @@ + { TLS_ECDHE_ECDSA_WITH_3DES_EDE_CBC_SHA, SSL_NOT_ALLOWED, SSL_NOT_ALLOWED }, + { TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, SSL_NOT_ALLOWED, SSL_NOT_ALLOWED }, + { TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256,SSL_NOT_ALLOWED, SSL_NOT_ALLOWED }, ++ { TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,SSL_NOT_ALLOWED, SSL_NOT_ALLOWED }, + { TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, SSL_NOT_ALLOWED, SSL_NOT_ALLOWED }, + { TLS_ECDH_RSA_WITH_NULL_SHA, SSL_ALLOWED, SSL_ALLOWED }, + { TLS_ECDH_RSA_WITH_RC4_128_SHA, SSL_NOT_ALLOWED, SSL_NOT_ALLOWED }, +@@ -105,6 +108,7 @@ + { TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, SSL_NOT_ALLOWED, SSL_NOT_ALLOWED }, + { TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, SSL_NOT_ALLOWED, SSL_NOT_ALLOWED }, + { TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, SSL_NOT_ALLOWED, SSL_NOT_ALLOWED }, ++ { TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, SSL_NOT_ALLOWED, SSL_NOT_ALLOWED }, + { TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, SSL_NOT_ALLOWED, SSL_NOT_ALLOWED }, + #endif /* NSS_ENABLE_ECC */ + { 0, SSL_NOT_ALLOWED, SSL_NOT_ALLOWED } +Index: net/third_party/nss/ssl/ssl3con.c +=================================================================== +--- net/third_party/nss/ssl/ssl3con.c (revision 217715) ++++ net/third_party/nss/ssl/ssl3con.c (working copy) +@@ -78,6 +78,13 @@ + static SECStatus Null_Cipher(void *ctx, unsigned char *output, int *outputLen, + int maxOutputLen, const unsigned char *input, + int inputLen); ++#ifndef NO_PKCS11_BYPASS ++static SECStatus ssl3_AESGCMBypass(ssl3KeyMaterial *keys, PRBool doDecrypt, ++ unsigned char *out, int *outlen, int maxout, ++ const unsigned char *in, int inlen, ++ const unsigned char *additionalData, ++ int additionalDataLen); ++#endif + + #define MAX_SEND_BUF_LENGTH 32000 /* watch for 16-bit integer overflow */ + #define MIN_SEND_BUF_LENGTH 4000 +@@ -90,6 +97,13 @@ + static ssl3CipherSuiteCfg cipherSuites[ssl_V3_SUITES_IMPLEMENTED] = { + /* cipher_suite policy enabled is_present*/ + #ifdef NSS_ENABLE_ECC ++ { TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,SSL_NOT_ALLOWED, PR_FALSE,PR_FALSE}, ++ { TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, SSL_NOT_ALLOWED, PR_FALSE,PR_FALSE}, ++#endif /* NSS_ENABLE_ECC */ ++ { TLS_DHE_RSA_WITH_AES_128_GCM_SHA256, SSL_NOT_ALLOWED, PR_TRUE,PR_FALSE}, ++ { TLS_RSA_WITH_AES_128_GCM_SHA256, SSL_NOT_ALLOWED, PR_TRUE,PR_FALSE}, ++ ++#ifdef NSS_ENABLE_ECC + { TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, SSL_NOT_ALLOWED, PR_FALSE,PR_FALSE}, + { TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, SSL_NOT_ALLOWED, PR_FALSE,PR_FALSE}, + #endif /* NSS_ENABLE_ECC */ +@@ -233,23 +247,30 @@ + + /* indexed by SSL3BulkCipher */ + static const ssl3BulkCipherDef bulk_cipher_defs[] = { +- /* cipher calg keySz secretSz type ivSz BlkSz keygen */ +- {cipher_null, calg_null, 0, 0, type_stream, 0, 0, kg_null}, +- {cipher_rc4, calg_rc4, 16, 16, type_stream, 0, 0, kg_strong}, +- {cipher_rc4_40, calg_rc4, 16, 5, type_stream, 0, 0, kg_export}, +- {cipher_rc4_56, calg_rc4, 16, 7, type_stream, 0, 0, kg_export}, +- {cipher_rc2, calg_rc2, 16, 16, type_block, 8, 8, kg_strong}, +- {cipher_rc2_40, calg_rc2, 16, 5, type_block, 8, 8, kg_export}, +- {cipher_des, calg_des, 8, 8, type_block, 8, 8, kg_strong}, +- {cipher_3des, calg_3des, 24, 24, type_block, 8, 8, kg_strong}, +- {cipher_des40, calg_des, 8, 5, type_block, 8, 8, kg_export}, +- {cipher_idea, calg_idea, 16, 16, type_block, 8, 8, kg_strong}, +- {cipher_aes_128, calg_aes, 16, 16, type_block, 16,16, kg_strong}, +- {cipher_aes_256, calg_aes, 32, 32, type_block, 16,16, kg_strong}, +- {cipher_camellia_128, calg_camellia,16, 16, type_block, 16,16, kg_strong}, +- {cipher_camellia_256, calg_camellia,32, 32, type_block, 16,16, kg_strong}, +- {cipher_seed, calg_seed, 16, 16, type_block, 16,16, kg_strong}, +- {cipher_missing, calg_null, 0, 0, type_stream, 0, 0, kg_null}, ++ /* |--------- Lengths --------| */ ++ /* cipher calg k s type i b t n */ ++ /* e e v l a o */ ++ /* y c | o g n */ ++ /* | r | c | c */ ++ /* | e | k | e */ ++ /* | t | | | | */ ++ {cipher_null, calg_null, 0, 0, type_stream, 0, 0, 0, 0}, ++ {cipher_rc4, calg_rc4, 16,16, type_stream, 0, 0, 0, 0}, ++ {cipher_rc4_40, calg_rc4, 16, 5, type_stream, 0, 0, 0, 0}, ++ {cipher_rc4_56, calg_rc4, 16, 7, type_stream, 0, 0, 0, 0}, ++ {cipher_rc2, calg_rc2, 16,16, type_block, 8, 8, 0, 0}, ++ {cipher_rc2_40, calg_rc2, 16, 5, type_block, 8, 8, 0, 0}, ++ {cipher_des, calg_des, 8, 8, type_block, 8, 8, 0, 0}, ++ {cipher_3des, calg_3des, 24,24, type_block, 8, 8, 0, 0}, ++ {cipher_des40, calg_des, 8, 5, type_block, 8, 8, 0, 0}, ++ {cipher_idea, calg_idea, 16,16, type_block, 8, 8, 0, 0}, ++ {cipher_aes_128, calg_aes, 16,16, type_block, 16,16, 0, 0}, ++ {cipher_aes_256, calg_aes, 32,32, type_block, 16,16, 0, 0}, ++ {cipher_camellia_128, calg_camellia, 16,16, type_block, 16,16, 0, 0}, ++ {cipher_camellia_256, calg_camellia, 32,32, type_block, 16,16, 0, 0}, ++ {cipher_seed, calg_seed, 16,16, type_block, 16,16, 0, 0}, ++ {cipher_aes_128_gcm, calg_aes_gcm, 16,16, type_aead, 4, 0,16, 8}, ++ {cipher_missing, calg_null, 0, 0, type_stream, 0, 0, 0, 0}, + }; + + static const ssl3KEADef kea_defs[] = +@@ -371,6 +392,11 @@ + {SSL_RSA_FIPS_WITH_3DES_EDE_CBC_SHA, cipher_3des, mac_sha, kea_rsa_fips}, + {SSL_RSA_FIPS_WITH_DES_CBC_SHA, cipher_des, mac_sha, kea_rsa_fips}, + ++ {TLS_DHE_RSA_WITH_AES_128_GCM_SHA256, cipher_aes_128_gcm, mac_aead, kea_dhe_rsa}, ++ {TLS_RSA_WITH_AES_128_GCM_SHA256, cipher_aes_128_gcm, mac_aead, kea_rsa}, ++ {TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, cipher_aes_128_gcm, mac_aead, kea_ecdhe_rsa}, ++ {TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, cipher_aes_128_gcm, mac_aead, kea_ecdhe_ecdsa}, ++ + #ifdef NSS_ENABLE_ECC + {TLS_ECDH_ECDSA_WITH_NULL_SHA, cipher_null, mac_sha, kea_ecdh_ecdsa}, + {TLS_ECDH_ECDSA_WITH_RC4_128_SHA, cipher_rc4, mac_sha, kea_ecdh_ecdsa}, +@@ -434,25 +460,29 @@ + { calg_aes , CKM_AES_CBC }, + { calg_camellia , CKM_CAMELLIA_CBC }, + { calg_seed , CKM_SEED_CBC }, ++ { calg_aes_gcm , CKM_AES_GCM }, + /* { calg_init , (CK_MECHANISM_TYPE)0x7fffffffL } */ + }; + +-#define mmech_null (CK_MECHANISM_TYPE)0x80000000L ++#define mmech_invalid (CK_MECHANISM_TYPE)0x80000000L + #define mmech_md5 CKM_SSL3_MD5_MAC + #define mmech_sha CKM_SSL3_SHA1_MAC + #define mmech_md5_hmac CKM_MD5_HMAC + #define mmech_sha_hmac CKM_SHA_1_HMAC + #define mmech_sha256_hmac CKM_SHA256_HMAC ++#define mmech_sha384_hmac CKM_SHA384_HMAC ++#define mmech_sha512_hmac CKM_SHA512_HMAC + + static const ssl3MACDef mac_defs[] = { /* indexed by SSL3MACAlgorithm */ + /* pad_size is only used for SSL 3.0 MAC. See RFC 6101 Sec. 5.2.3.1. */ + /* mac mmech pad_size mac_size */ +- { mac_null, mmech_null, 0, 0 }, ++ { mac_null, mmech_invalid, 0, 0 }, + { mac_md5, mmech_md5, 48, MD5_LENGTH }, + { mac_sha, mmech_sha, 40, SHA1_LENGTH}, + {hmac_md5, mmech_md5_hmac, 0, MD5_LENGTH }, + {hmac_sha, mmech_sha_hmac, 0, SHA1_LENGTH}, + {hmac_sha256, mmech_sha256_hmac, 0, SHA256_LENGTH}, ++ { mac_aead, mmech_invalid, 0, 0 }, + }; + + /* indexed by SSL3BulkCipher */ +@@ -472,6 +502,7 @@ + "Camellia-128", + "Camellia-256", + "SEED-CBC", ++ "AES-128-GCM", + "missing" + }; + +@@ -598,9 +629,13 @@ + case TLS_DHE_RSA_WITH_AES_256_CBC_SHA256: + case TLS_RSA_WITH_AES_256_CBC_SHA256: + case TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256: ++ case TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256: + case TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256: ++ case TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256: + case TLS_DHE_RSA_WITH_AES_128_CBC_SHA256: ++ case TLS_DHE_RSA_WITH_AES_128_GCM_SHA256: + case TLS_RSA_WITH_AES_128_CBC_SHA256: ++ case TLS_RSA_WITH_AES_128_GCM_SHA256: + case TLS_RSA_WITH_NULL_SHA256: + return version >= SSL_LIBRARY_VERSION_TLS_1_2; + default: +@@ -1360,7 +1395,7 @@ + cipher = suite_def->bulk_cipher_alg; + kea = suite_def->key_exchange_alg; + mac = suite_def->mac_alg; +- if (mac <= ssl_mac_sha && isTLS) ++ if (mac <= ssl_mac_sha && mac != ssl_mac_null && isTLS) + mac += 2; + + ss->ssl3.hs.suite_def = suite_def; +@@ -1554,7 +1589,6 @@ + unsigned int optArg2 = 0; + PRBool server_encrypts = ss->sec.isServer; + SSLCipherAlgorithm calg; +- SSLCompressionMethod compression_method; + SECStatus rv; + + PORT_Assert(ss->opt.noLocks || ssl_HaveSSL3HandshakeLock(ss)); +@@ -1565,8 +1599,18 @@ + cipher_def = pwSpec->cipher_def; + + calg = cipher_def->calg; +- compression_method = pwSpec->compression_method; + ++ if (calg == calg_aes_gcm) { ++ pwSpec->encode = NULL; ++ pwSpec->decode = NULL; ++ pwSpec->destroy = NULL; ++ pwSpec->encodeContext = NULL; ++ pwSpec->decodeContext = NULL; ++ pwSpec->aead = ssl3_AESGCMBypass; ++ ssl3_InitCompressionContext(pwSpec); ++ return SECSuccess; ++ } ++ + serverContext = pwSpec->server.cipher_context; + clientContext = pwSpec->client.cipher_context; + +@@ -1721,6 +1765,195 @@ + return param; + } + ++/* ssl3_BuildRecordPseudoHeader writes the SSL/TLS pseudo-header (the data ++ * which is included in the MAC or AEAD additional data) to |out| and returns ++ * its length. See https://tools.ietf.org/html/rfc5246#section-6.2.3.3 for the ++ * definition of the AEAD additional data. ++ * ++ * TLS pseudo-header includes the record's version field, SSL's doesn't. Which ++ * pseudo-header defintiion to use should be decided based on the version of ++ * the protocol that was negotiated when the cipher spec became current, NOT ++ * based on the version value in the record itself, and the decision is passed ++ * to this function as the |includesVersion| argument. But, the |version| ++ * argument should be the record's version value. ++ */ ++static unsigned int ++ssl3_BuildRecordPseudoHeader(unsigned char *out, ++ SSL3SequenceNumber seq_num, ++ SSL3ContentType type, ++ PRBool includesVersion, ++ SSL3ProtocolVersion version, ++ PRBool isDTLS, ++ int length) ++{ ++ out[0] = (unsigned char)(seq_num.high >> 24); ++ out[1] = (unsigned char)(seq_num.high >> 16); ++ out[2] = (unsigned char)(seq_num.high >> 8); ++ out[3] = (unsigned char)(seq_num.high >> 0); ++ out[4] = (unsigned char)(seq_num.low >> 24); ++ out[5] = (unsigned char)(seq_num.low >> 16); ++ out[6] = (unsigned char)(seq_num.low >> 8); ++ out[7] = (unsigned char)(seq_num.low >> 0); ++ out[8] = type; ++ ++ /* SSL3 MAC doesn't include the record's version field. */ ++ if (!includesVersion) { ++ out[9] = MSB(length); ++ out[10] = LSB(length); ++ return 11; ++ } ++ ++ /* TLS MAC and AEAD additional data include version. */ ++ if (isDTLS) { ++ SSL3ProtocolVersion dtls_version; ++ ++ dtls_version = dtls_TLSVersionToDTLSVersion(version); ++ out[9] = MSB(dtls_version); ++ out[10] = LSB(dtls_version); ++ } else { ++ out[9] = MSB(version); ++ out[10] = LSB(version); ++ } ++ out[11] = MSB(length); ++ out[12] = LSB(length); ++ return 13; ++} ++ ++static SECStatus ++ssl3_AESGCM(ssl3KeyMaterial *keys, ++ PRBool doDecrypt, ++ unsigned char *out, ++ int *outlen, ++ int maxout, ++ const unsigned char *in, ++ int inlen, ++ const unsigned char *additionalData, ++ int additionalDataLen) ++{ ++ SECItem param; ++ SECStatus rv = SECFailure; ++ unsigned char nonce[12]; ++ unsigned int uOutLen; ++ CK_GCM_PARAMS gcmParams; ++ ++ static const int tagSize = 16; ++ static const int explicitNonceLen = 8; ++ ++ /* See https://tools.ietf.org/html/rfc5288#section-3 for details of how the ++ * nonce is formed. */ ++ memcpy(nonce, keys->write_iv, 4); ++ if (doDecrypt) { ++ memcpy(nonce + 4, in, explicitNonceLen); ++ in += explicitNonceLen; ++ inlen -= explicitNonceLen; ++ *outlen = 0; ++ } else { ++ if (maxout < explicitNonceLen) { ++ PORT_SetError(SEC_ERROR_INPUT_LEN); ++ return SECFailure; ++ } ++ /* Use the 64-bit sequence number as the explicit nonce. */ ++ memcpy(nonce + 4, additionalData, explicitNonceLen); ++ memcpy(out, additionalData, explicitNonceLen); ++ out += explicitNonceLen; ++ maxout -= explicitNonceLen; ++ *outlen = explicitNonceLen; ++ } ++ ++ param.type = siBuffer; ++ param.data = (unsigned char *) &gcmParams; ++ param.len = sizeof(gcmParams); ++ gcmParams.pIv = nonce; ++ gcmParams.ulIvLen = sizeof(nonce); ++ gcmParams.pAAD = (unsigned char *)additionalData; /* const cast */ ++ gcmParams.ulAADLen = additionalDataLen; ++ gcmParams.ulTagBits = tagSize * 8; ++ ++ if (doDecrypt) { ++ rv = PK11_Decrypt(keys->write_key, CKM_AES_GCM, ¶m, out, &uOutLen, ++ maxout, in, inlen); ++ } else { ++ rv = PK11_Encrypt(keys->write_key, CKM_AES_GCM, ¶m, out, &uOutLen, ++ maxout, in, inlen); ++ } ++ *outlen += (int) uOutLen; ++ ++ return rv; ++} ++ ++#ifndef NO_PKCS11_BYPASS ++static SECStatus ++ssl3_AESGCMBypass(ssl3KeyMaterial *keys, ++ PRBool doDecrypt, ++ unsigned char *out, ++ int *outlen, ++ int maxout, ++ const unsigned char *in, ++ int inlen, ++ const unsigned char *additionalData, ++ int additionalDataLen) ++{ ++ SECStatus rv = SECFailure; ++ unsigned char nonce[12]; ++ unsigned int uOutLen; ++ AESContext *cx; ++ CK_GCM_PARAMS gcmParams; ++ ++ static const int tagSize = 16; ++ static const int explicitNonceLen = 8; ++ ++ /* See https://tools.ietf.org/html/rfc5288#section-3 for details of how the ++ * nonce is formed. */ ++ PORT_Assert(keys->write_iv_item.len == 4); ++ if (keys->write_iv_item.len != 4) { ++ PORT_SetError(SEC_ERROR_LIBRARY_FAILURE); ++ return SECFailure; ++ } ++ memcpy(nonce, keys->write_iv_item.data, 4); ++ if (doDecrypt) { ++ memcpy(nonce + 4, in, explicitNonceLen); ++ in += explicitNonceLen; ++ inlen -= explicitNonceLen; ++ *outlen = 0; ++ } else { ++ if (maxout < explicitNonceLen) { ++ PORT_SetError(SEC_ERROR_INPUT_LEN); ++ return SECFailure; ++ } ++ /* Use the 64-bit sequence number as the explicit nonce. */ ++ memcpy(nonce + 4, additionalData, explicitNonceLen); ++ memcpy(out, additionalData, explicitNonceLen); ++ out += explicitNonceLen; ++ maxout -= explicitNonceLen; ++ *outlen = explicitNonceLen; ++ } ++ ++ gcmParams.pIv = nonce; ++ gcmParams.ulIvLen = sizeof(nonce); ++ gcmParams.pAAD = (unsigned char *)additionalData; /* const cast */ ++ gcmParams.ulAADLen = additionalDataLen; ++ gcmParams.ulTagBits = tagSize * 8; ++ ++ cx = (AESContext *)keys->cipher_context; ++ rv = AES_InitContext(cx, keys->write_key_item.data, ++ keys->write_key_item.len, ++ (unsigned char *)&gcmParams, NSS_AES_GCM, !doDecrypt, ++ AES_BLOCK_SIZE); ++ if (rv != SECSuccess) { ++ return rv; ++ } ++ if (doDecrypt) { ++ rv = AES_Decrypt(cx, out, &uOutLen, maxout, in, inlen); ++ } else { ++ rv = AES_Encrypt(cx, out, &uOutLen, maxout, in, inlen); ++ } ++ AES_DestroyContext(cx, PR_FALSE); ++ *outlen += (int) uOutLen; ++ ++ return rv; ++} ++#endif ++ + /* Initialize encryption and MAC contexts for pending spec. + * Master Secret already is derived. + * Caller holds Spec write lock. +@@ -1748,14 +1981,27 @@ + pwSpec = ss->ssl3.pwSpec; + cipher_def = pwSpec->cipher_def; + macLength = pwSpec->mac_size; ++ calg = cipher_def->calg; ++ PORT_Assert(alg2Mech[calg].calg == calg); + ++ pwSpec->client.write_mac_context = NULL; ++ pwSpec->server.write_mac_context = NULL; ++ ++ if (calg == calg_aes_gcm) { ++ pwSpec->encode = NULL; ++ pwSpec->decode = NULL; ++ pwSpec->destroy = NULL; ++ pwSpec->encodeContext = NULL; ++ pwSpec->decodeContext = NULL; ++ pwSpec->aead = ssl3_AESGCM; ++ return SECSuccess; ++ } ++ + /* + ** Now setup the MAC contexts, + ** crypto contexts are setup below. + */ + +- pwSpec->client.write_mac_context = NULL; +- pwSpec->server.write_mac_context = NULL; + mac_mech = pwSpec->mac_def->mmech; + mac_param.data = (unsigned char *)&macLength; + mac_param.len = sizeof(macLength); +@@ -1778,9 +2024,6 @@ + ** Now setup the crypto contexts. + */ + +- calg = cipher_def->calg; +- PORT_Assert(alg2Mech[calg].calg == calg); +- + if (calg == calg_null) { + pwSpec->encode = Null_Cipher; + pwSpec->decode = Null_Cipher; +@@ -1988,10 +2231,8 @@ + ssl3_ComputeRecordMAC( + ssl3CipherSpec * spec, + PRBool useServerMacKey, +- PRBool isDTLS, +- SSL3ContentType type, +- SSL3ProtocolVersion version, +- SSL3SequenceNumber seq_num, ++ const unsigned char *header, ++ unsigned int headerLen, + const SSL3Opaque * input, + int inputLength, + unsigned char * outbuf, +@@ -1999,56 +2240,8 @@ + { + const ssl3MACDef * mac_def; + SECStatus rv; +-#ifndef NO_PKCS11_BYPASS +- PRBool isTLS; +-#endif +- unsigned int tempLen; +- unsigned char temp[MAX_MAC_LENGTH]; + +- temp[0] = (unsigned char)(seq_num.high >> 24); +- temp[1] = (unsigned char)(seq_num.high >> 16); +- temp[2] = (unsigned char)(seq_num.high >> 8); +- temp[3] = (unsigned char)(seq_num.high >> 0); +- temp[4] = (unsigned char)(seq_num.low >> 24); +- temp[5] = (unsigned char)(seq_num.low >> 16); +- temp[6] = (unsigned char)(seq_num.low >> 8); +- temp[7] = (unsigned char)(seq_num.low >> 0); +- temp[8] = type; +- +- /* TLS MAC includes the record's version field, SSL's doesn't. +- ** We decide which MAC defintiion to use based on the version of +- ** the protocol that was negotiated when the spec became current, +- ** NOT based on the version value in the record itself. +- ** But, we use the record'v version value in the computation. +- */ +- if (spec->version <= SSL_LIBRARY_VERSION_3_0) { +- temp[9] = MSB(inputLength); +- temp[10] = LSB(inputLength); +- tempLen = 11; +-#ifndef NO_PKCS11_BYPASS +- isTLS = PR_FALSE; +-#endif +- } else { +- /* New TLS hash includes version. */ +- if (isDTLS) { +- SSL3ProtocolVersion dtls_version; +- +- dtls_version = dtls_TLSVersionToDTLSVersion(version); +- temp[9] = MSB(dtls_version); +- temp[10] = LSB(dtls_version); +- } else { +- temp[9] = MSB(version); +- temp[10] = LSB(version); +- } +- temp[11] = MSB(inputLength); +- temp[12] = LSB(inputLength); +- tempLen = 13; +-#ifndef NO_PKCS11_BYPASS +- isTLS = PR_TRUE; +-#endif +- } +- +- PRINT_BUF(95, (NULL, "frag hash1: temp", temp, tempLen)); ++ PRINT_BUF(95, (NULL, "frag hash1: header", header, headerLen)); + PRINT_BUF(95, (NULL, "frag hash1: input", input, inputLength)); + + mac_def = spec->mac_def; +@@ -2093,7 +2286,10 @@ + return SECFailure; + } + +- if (!isTLS) { ++ if (spec->version <= SSL_LIBRARY_VERSION_3_0) { ++ unsigned int tempLen; ++ unsigned char temp[MAX_MAC_LENGTH]; ++ + /* compute "inner" part of SSL3 MAC */ + hashObj->begin(write_mac_context); + if (useServerMacKey) +@@ -2105,7 +2301,7 @@ + spec->client.write_mac_key_item.data, + spec->client.write_mac_key_item.len); + hashObj->update(write_mac_context, mac_pad_1, pad_bytes); +- hashObj->update(write_mac_context, temp, tempLen); ++ hashObj->update(write_mac_context, header, headerLen); + hashObj->update(write_mac_context, input, inputLength); + hashObj->end(write_mac_context, temp, &tempLen, sizeof temp); + +@@ -2136,7 +2332,7 @@ + } + if (rv == SECSuccess) { + HMAC_Begin(cx); +- HMAC_Update(cx, temp, tempLen); ++ HMAC_Update(cx, header, headerLen); + HMAC_Update(cx, input, inputLength); + rv = HMAC_Finish(cx, outbuf, outLength, spec->mac_size); + HMAC_Destroy(cx, PR_FALSE); +@@ -2150,7 +2346,7 @@ + (useServerMacKey ? spec->server.write_mac_context + : spec->client.write_mac_context); + rv = PK11_DigestBegin(mac_context); +- rv |= PK11_DigestOp(mac_context, temp, tempLen); ++ rv |= PK11_DigestOp(mac_context, header, headerLen); + rv |= PK11_DigestOp(mac_context, input, inputLength); + rv |= PK11_DigestFinal(mac_context, outbuf, outLength, spec->mac_size); + } +@@ -2190,10 +2386,8 @@ + ssl3_ComputeRecordMACConstantTime( + ssl3CipherSpec * spec, + PRBool useServerMacKey, +- PRBool isDTLS, +- SSL3ContentType type, +- SSL3ProtocolVersion version, +- SSL3SequenceNumber seq_num, ++ const unsigned char *header, ++ unsigned int headerLen, + const SSL3Opaque * input, + int inputLen, + int originalLen, +@@ -2205,9 +2399,7 @@ + PK11Context * mac_context; + SECItem param; + SECStatus rv; +- unsigned char header[13]; + PK11SymKey * key; +- int recordLength; + + PORT_Assert(inputLen >= spec->mac_size); + PORT_Assert(originalLen >= inputLen); +@@ -2223,42 +2415,15 @@ + return SECSuccess; + } + +- header[0] = (unsigned char)(seq_num.high >> 24); +- header[1] = (unsigned char)(seq_num.high >> 16); +- header[2] = (unsigned char)(seq_num.high >> 8); +- header[3] = (unsigned char)(seq_num.high >> 0); +- header[4] = (unsigned char)(seq_num.low >> 24); +- header[5] = (unsigned char)(seq_num.low >> 16); +- header[6] = (unsigned char)(seq_num.low >> 8); +- header[7] = (unsigned char)(seq_num.low >> 0); +- header[8] = type; +- + macType = CKM_NSS_HMAC_CONSTANT_TIME; +- recordLength = inputLen - spec->mac_size; + if (spec->version <= SSL_LIBRARY_VERSION_3_0) { + macType = CKM_NSS_SSL3_MAC_CONSTANT_TIME; +- header[9] = recordLength >> 8; +- header[10] = recordLength; +- params.ulHeaderLen = 11; +- } else { +- if (isDTLS) { +- SSL3ProtocolVersion dtls_version; +- +- dtls_version = dtls_TLSVersionToDTLSVersion(version); +- header[9] = dtls_version >> 8; +- header[10] = dtls_version; +- } else { +- header[9] = version >> 8; +- header[10] = version; +- } +- header[11] = recordLength >> 8; +- header[12] = recordLength; +- params.ulHeaderLen = 13; + } + + params.macAlg = spec->mac_def->mmech; + params.ulBodyTotalLen = originalLen; +- params.pHeader = header; ++ params.pHeader = (unsigned char *) header; /* const cast */ ++ params.ulHeaderLen = headerLen; + + param.data = (unsigned char*) ¶ms; + param.len = sizeof(params); +@@ -2291,9 +2456,8 @@ + /* ssl3_ComputeRecordMAC expects the MAC to have been removed from the + * length already. */ + inputLen -= spec->mac_size; +- return ssl3_ComputeRecordMAC(spec, useServerMacKey, isDTLS, type, +- version, seq_num, input, inputLen, +- outbuf, outLen); ++ return ssl3_ComputeRecordMAC(spec, useServerMacKey, header, headerLen, ++ input, inputLen, outbuf, outLen); + } + + static PRBool +@@ -2345,6 +2509,8 @@ + PRUint16 headerLen; + int ivLen = 0; + int cipherBytes = 0; ++ unsigned char pseudoHeader[13]; ++ unsigned int pseudoHeaderLen; + + cipher_def = cwSpec->cipher_def; + headerLen = isDTLS ? DTLS_RECORD_HEADER_LENGTH : SSL3_RECORD_HEADER_LENGTH; +@@ -2390,86 +2556,117 @@ + contentLen = outlen; + } + +- /* +- * Add the MAC +- */ +- rv = ssl3_ComputeRecordMAC( cwSpec, isServer, isDTLS, +- type, cwSpec->version, cwSpec->write_seq_num, pIn, contentLen, +- wrBuf->buf + headerLen + ivLen + contentLen, &macLen); +- if (rv != SECSuccess) { +- ssl_MapLowLevelError(SSL_ERROR_MAC_COMPUTATION_FAILURE); +- return SECFailure; +- } +- p1Len = contentLen; +- p2Len = macLen; +- fragLen = contentLen + macLen; /* needs to be encrypted */ +- PORT_Assert(fragLen <= MAX_FRAGMENT_LENGTH + 1024); ++ pseudoHeaderLen = ssl3_BuildRecordPseudoHeader( ++ pseudoHeader, cwSpec->write_seq_num, type, ++ cwSpec->version >= SSL_LIBRARY_VERSION_TLS_1_0, cwSpec->version, ++ isDTLS, contentLen); ++ PORT_Assert(pseudoHeaderLen <= sizeof(pseudoHeader)); ++ if (cipher_def->type == type_aead) { ++ const int nonceLen = cipher_def->explicit_nonce_size; ++ const int tagLen = cipher_def->tag_size; + +- /* +- * Pad the text (if we're doing a block cipher) +- * then Encrypt it +- */ +- if (cipher_def->type == type_block) { +- unsigned char * pBuf; +- int padding_length; +- int i; ++ if (headerLen + nonceLen + contentLen + tagLen > wrBuf->space) { ++ PORT_SetError(SEC_ERROR_LIBRARY_FAILURE); ++ return SECFailure; ++ } + +- oddLen = contentLen % cipher_def->block_size; +- /* Assume blockSize is a power of two */ +- padding_length = cipher_def->block_size - 1 - +- ((fragLen) & (cipher_def->block_size - 1)); +- fragLen += padding_length + 1; +- PORT_Assert((fragLen % cipher_def->block_size) == 0); +- +- /* Pad according to TLS rules (also acceptable to SSL3). */ +- pBuf = &wrBuf->buf[headerLen + ivLen + fragLen - 1]; +- for (i = padding_length + 1; i > 0; --i) { +- *pBuf-- = padding_length; ++ cipherBytes = contentLen; ++ rv = cwSpec->aead( ++ isServer ? &cwSpec->server : &cwSpec->client, ++ PR_FALSE, /* do encrypt */ ++ wrBuf->buf + headerLen, /* output */ ++ &cipherBytes, /* out len */ ++ wrBuf->space - headerLen, /* max out */ ++ pIn, contentLen, /* input */ ++ pseudoHeader, pseudoHeaderLen); ++ if (rv != SECSuccess) { ++ PORT_SetError(SSL_ERROR_ENCRYPTION_FAILURE); ++ return SECFailure; + } +- /* now, if contentLen is not a multiple of block size, fix it */ +- p2Len = fragLen - p1Len; +- } +- if (p1Len < 256) { +- oddLen = p1Len; +- p1Len = 0; + } else { +- p1Len -= oddLen; +- } +- if (oddLen) { +- p2Len += oddLen; +- PORT_Assert( (cipher_def->block_size < 2) || \ +- (p2Len % cipher_def->block_size) == 0); +- memmove(wrBuf->buf + headerLen + ivLen + p1Len, pIn + p1Len, oddLen); +- } +- if (p1Len > 0) { +- int cipherBytesPart1 = -1; +- rv = cwSpec->encode( cwSpec->encodeContext, +- wrBuf->buf + headerLen + ivLen, /* output */ +- &cipherBytesPart1, /* actual outlen */ +- p1Len, /* max outlen */ +- pIn, p1Len); /* input, and inputlen */ +- PORT_Assert(rv == SECSuccess && cipherBytesPart1 == (int) p1Len); +- if (rv != SECSuccess || cipherBytesPart1 != (int) p1Len) { +- PORT_SetError(SSL_ERROR_ENCRYPTION_FAILURE); ++ /* ++ * Add the MAC ++ */ ++ rv = ssl3_ComputeRecordMAC(cwSpec, isServer, ++ pseudoHeader, pseudoHeaderLen, pIn, contentLen, ++ wrBuf->buf + headerLen + ivLen + contentLen, &macLen); ++ if (rv != SECSuccess) { ++ ssl_MapLowLevelError(SSL_ERROR_MAC_COMPUTATION_FAILURE); + return SECFailure; + } +- cipherBytes += cipherBytesPart1; ++ p1Len = contentLen; ++ p2Len = macLen; ++ fragLen = contentLen + macLen; /* needs to be encrypted */ ++ PORT_Assert(fragLen <= MAX_FRAGMENT_LENGTH + 1024); ++ ++ /* ++ * Pad the text (if we're doing a block cipher) ++ * then Encrypt it ++ */ ++ if (cipher_def->type == type_block) { ++ unsigned char * pBuf; ++ int padding_length; ++ int i; ++ ++ oddLen = contentLen % cipher_def->block_size; ++ /* Assume blockSize is a power of two */ ++ padding_length = cipher_def->block_size - 1 - ++ ((fragLen) & (cipher_def->block_size - 1)); ++ fragLen += padding_length + 1; ++ PORT_Assert((fragLen % cipher_def->block_size) == 0); ++ ++ /* Pad according to TLS rules (also acceptable to SSL3). */ ++ pBuf = &wrBuf->buf[headerLen + ivLen + fragLen - 1]; ++ for (i = padding_length + 1; i > 0; --i) { ++ *pBuf-- = padding_length; ++ } ++ /* now, if contentLen is not a multiple of block size, fix it */ ++ p2Len = fragLen - p1Len; ++ } ++ if (p1Len < 256) { ++ oddLen = p1Len; ++ p1Len = 0; ++ } else { ++ p1Len -= oddLen; ++ } ++ if (oddLen) { ++ p2Len += oddLen; ++ PORT_Assert( (cipher_def->block_size < 2) || \ ++ (p2Len % cipher_def->block_size) == 0); ++ memmove(wrBuf->buf + headerLen + ivLen + p1Len, pIn + p1Len, ++ oddLen); ++ } ++ if (p1Len > 0) { ++ int cipherBytesPart1 = -1; ++ rv = cwSpec->encode( cwSpec->encodeContext, ++ wrBuf->buf + headerLen + ivLen, /* output */ ++ &cipherBytesPart1, /* actual outlen */ ++ p1Len, /* max outlen */ ++ pIn, p1Len); /* input, and inputlen */ ++ PORT_Assert(rv == SECSuccess && cipherBytesPart1 == (int) p1Len); ++ if (rv != SECSuccess || cipherBytesPart1 != (int) p1Len) { ++ PORT_SetError(SSL_ERROR_ENCRYPTION_FAILURE); ++ return SECFailure; ++ } ++ cipherBytes += cipherBytesPart1; ++ } ++ if (p2Len > 0) { ++ int cipherBytesPart2 = -1; ++ rv = cwSpec->encode( cwSpec->encodeContext, ++ wrBuf->buf + headerLen + ivLen + p1Len, ++ &cipherBytesPart2, /* output and actual outLen */ ++ p2Len, /* max outlen */ ++ wrBuf->buf + headerLen + ivLen + p1Len, ++ p2Len); /* input and inputLen*/ ++ PORT_Assert(rv == SECSuccess && cipherBytesPart2 == (int) p2Len); ++ if (rv != SECSuccess || cipherBytesPart2 != (int) p2Len) { ++ PORT_SetError(SSL_ERROR_ENCRYPTION_FAILURE); ++ return SECFailure; ++ } ++ cipherBytes += cipherBytesPart2; ++ } + } +- if (p2Len > 0) { +- int cipherBytesPart2 = -1; +- rv = cwSpec->encode( cwSpec->encodeContext, +- wrBuf->buf + headerLen + ivLen + p1Len, +- &cipherBytesPart2, /* output and actual outLen */ +- p2Len, /* max outlen */ +- wrBuf->buf + headerLen + ivLen + p1Len, +- p2Len); /* input and inputLen*/ +- PORT_Assert(rv == SECSuccess && cipherBytesPart2 == (int) p2Len); +- if (rv != SECSuccess || cipherBytesPart2 != (int) p2Len) { +- PORT_SetError(SSL_ERROR_ENCRYPTION_FAILURE); +- return SECFailure; +- } +- cipherBytes += cipherBytesPart2; +- } ++ + PORT_Assert(cipherBytes <= MAX_FRAGMENT_LENGTH + 1024); + + wrBuf->len = cipherBytes + headerLen; +@@ -3012,9 +3209,6 @@ + static SECStatus + ssl3_IllegalParameter(sslSocket *ss) + { +- PRBool isTLS; +- +- isTLS = (PRBool)(ss->ssl3.pwSpec->version > SSL_LIBRARY_VERSION_3_0); + (void)SSL3_SendAlert(ss, alert_fatal, illegal_parameter); + PORT_SetError(ss->sec.isServer ? SSL_ERROR_BAD_CLIENT + : SSL_ERROR_BAD_SERVER ); +@@ -3538,7 +3732,6 @@ + } + + key_material_params.bIsExport = (CK_BBOOL)(kea_def->is_limited); +- /* was: (CK_BBOOL)(cipher_def->keygen_mode != kg_strong); */ + + key_material_params.RandomInfo.pClientRandom = cr; + key_material_params.RandomInfo.ulClientRandomLen = SSL3_RANDOM_LENGTH; +@@ -9946,7 +10139,6 @@ + static void + ssl3_RecordKeyLog(sslSocket *ss) + { +- sslSessionID *sid; + SECStatus rv; + SECItem *keyData; + char buf[14 /* "CLIENT_RANDOM " */ + +@@ -9958,8 +10150,6 @@ + + PORT_Assert( ss->opt.noLocks || ssl_HaveSSL3HandshakeLock(ss)); + +- sid = ss->sec.ci.sid; +- + if (!ssl_keylog_iob) + return; + +@@ -11095,6 +11285,8 @@ + unsigned int originalLen = 0; + unsigned int good; + unsigned int minLength; ++ unsigned char header[13]; ++ unsigned int headerLen; + + PORT_Assert( ss->opt.noLocks || ssl_HaveRecvBufLock(ss) ); + +@@ -11171,12 +11363,14 @@ + /* With >= TLS 1.1, CBC records have an explicit IV. */ + minLength += cipher_def->iv_size; + } ++ } else if (cipher_def->type == type_aead) { ++ minLength = cipher_def->explicit_nonce_size + cipher_def->tag_size; + } + + /* We can perform this test in variable time because the record's total + * length and the ciphersuite are both public knowledge. */ + if (cText->buf->len < minLength) { +- goto decrypt_loser; ++ goto decrypt_loser; + } + + if (cipher_def->type == type_block && +@@ -11244,78 +11438,104 @@ + return SECFailure; + } + +- if (cipher_def->type == type_block && +- ((cText->buf->len - ivLen) % cipher_def->block_size) != 0) { +- goto decrypt_loser; +- } ++ rType = cText->type; ++ if (cipher_def->type == type_aead) { ++ /* XXX For many AEAD ciphers, the plaintext is shorter than the ++ * ciphertext by a fixed byte count, but it is not true in general. ++ * Each AEAD cipher should provide a function that returns the ++ * plaintext length for a given ciphertext. */ ++ unsigned int decryptedLen = ++ cText->buf->len - cipher_def->explicit_nonce_size - ++ cipher_def->tag_size; ++ headerLen = ssl3_BuildRecordPseudoHeader( ++ header, IS_DTLS(ss) ? cText->seq_num : crSpec->read_seq_num, ++ rType, isTLS, cText->version, IS_DTLS(ss), decryptedLen); ++ PORT_Assert(headerLen <= sizeof(header)); ++ rv = crSpec->aead( ++ ss->sec.isServer ? &crSpec->client : &crSpec->server, ++ PR_TRUE, /* do decrypt */ ++ plaintext->buf, /* out */ ++ (int*) &plaintext->len, /* outlen */ ++ plaintext->space, /* maxout */ ++ cText->buf->buf, /* in */ ++ cText->buf->len, /* inlen */ ++ header, headerLen); ++ if (rv != SECSuccess) { ++ good = 0; ++ } ++ } else { ++ if (cipher_def->type == type_block && ++ ((cText->buf->len - ivLen) % cipher_def->block_size) != 0) { ++ goto decrypt_loser; ++ } + +- /* decrypt from cText buf to plaintext. */ +- rv = crSpec->decode( +- crSpec->decodeContext, plaintext->buf, (int *)&plaintext->len, +- plaintext->space, cText->buf->buf + ivLen, cText->buf->len - ivLen); +- if (rv != SECSuccess) { +- goto decrypt_loser; +- } ++ /* decrypt from cText buf to plaintext. */ ++ rv = crSpec->decode( ++ crSpec->decodeContext, plaintext->buf, (int *)&plaintext->len, ++ plaintext->space, cText->buf->buf + ivLen, cText->buf->len - ivLen); ++ if (rv != SECSuccess) { ++ goto decrypt_loser; ++ } + +- PRINT_BUF(80, (ss, "cleartext:", plaintext->buf, plaintext->len)); ++ PRINT_BUF(80, (ss, "cleartext:", plaintext->buf, plaintext->len)); + +- originalLen = plaintext->len; ++ originalLen = plaintext->len; + +- /* If it's a block cipher, check and strip the padding. */ +- if (cipher_def->type == type_block) { +- const unsigned int blockSize = cipher_def->block_size; +- const unsigned int macSize = crSpec->mac_size; ++ /* If it's a block cipher, check and strip the padding. */ ++ if (cipher_def->type == type_block) { ++ const unsigned int blockSize = cipher_def->block_size; ++ const unsigned int macSize = crSpec->mac_size; + +- if (crSpec->version <= SSL_LIBRARY_VERSION_3_0) { +- good &= SECStatusToMask(ssl_RemoveSSLv3CBCPadding( +- plaintext, blockSize, macSize)); +- } else { +- good &= SECStatusToMask(ssl_RemoveTLSCBCPadding( +- plaintext, macSize)); ++ if (!isTLS) { ++ good &= SECStatusToMask(ssl_RemoveSSLv3CBCPadding( ++ plaintext, blockSize, macSize)); ++ } else { ++ good &= SECStatusToMask(ssl_RemoveTLSCBCPadding( ++ plaintext, macSize)); ++ } + } +- } + +- /* compute the MAC */ +- rType = cText->type; +- if (cipher_def->type == type_block) { +- rv = ssl3_ComputeRecordMACConstantTime( +- crSpec, (PRBool)(!ss->sec.isServer), +- IS_DTLS(ss), rType, cText->version, +- IS_DTLS(ss) ? cText->seq_num : crSpec->read_seq_num, +- plaintext->buf, plaintext->len, originalLen, +- hash, &hashBytes); ++ /* compute the MAC */ ++ headerLen = ssl3_BuildRecordPseudoHeader( ++ header, IS_DTLS(ss) ? cText->seq_num : crSpec->read_seq_num, ++ rType, isTLS, cText->version, IS_DTLS(ss), ++ plaintext->len - crSpec->mac_size); ++ PORT_Assert(headerLen <= sizeof(header)); ++ if (cipher_def->type == type_block) { ++ rv = ssl3_ComputeRecordMACConstantTime( ++ crSpec, (PRBool)(!ss->sec.isServer), header, headerLen, ++ plaintext->buf, plaintext->len, originalLen, ++ hash, &hashBytes); + +- ssl_CBCExtractMAC(plaintext, originalLen, givenHashBuf, +- crSpec->mac_size); +- givenHash = givenHashBuf; ++ ssl_CBCExtractMAC(plaintext, originalLen, givenHashBuf, ++ crSpec->mac_size); ++ givenHash = givenHashBuf; + +- /* plaintext->len will always have enough space to remove the MAC +- * because in ssl_Remove{SSLv3|TLS}CBCPadding we only adjust +- * plaintext->len if the result has enough space for the MAC and we +- * tested the unadjusted size against minLength, above. */ +- plaintext->len -= crSpec->mac_size; +- } else { +- /* This is safe because we checked the minLength above. */ +- plaintext->len -= crSpec->mac_size; ++ /* plaintext->len will always have enough space to remove the MAC ++ * because in ssl_Remove{SSLv3|TLS}CBCPadding we only adjust ++ * plaintext->len if the result has enough space for the MAC and we ++ * tested the unadjusted size against minLength, above. */ ++ plaintext->len -= crSpec->mac_size; ++ } else { ++ /* This is safe because we checked the minLength above. */ ++ plaintext->len -= crSpec->mac_size; + +- rv = ssl3_ComputeRecordMAC( +- crSpec, (PRBool)(!ss->sec.isServer), +- IS_DTLS(ss), rType, cText->version, +- IS_DTLS(ss) ? cText->seq_num : crSpec->read_seq_num, +- plaintext->buf, plaintext->len, +- hash, &hashBytes); ++ rv = ssl3_ComputeRecordMAC( ++ crSpec, (PRBool)(!ss->sec.isServer), header, headerLen, ++ plaintext->buf, plaintext->len, hash, &hashBytes); + +- /* We can read the MAC directly from the record because its location is +- * public when a stream cipher is used. */ +- givenHash = plaintext->buf + plaintext->len; +- } ++ /* We can read the MAC directly from the record because its location ++ * is public when a stream cipher is used. */ ++ givenHash = plaintext->buf + plaintext->len; ++ } + +- good &= SECStatusToMask(rv); ++ good &= SECStatusToMask(rv); + +- if (hashBytes != (unsigned)crSpec->mac_size || +- NSS_SecureMemcmp(givenHash, hash, crSpec->mac_size) != 0) { +- /* We're allowed to leak whether or not the MAC check was correct */ +- good = 0; ++ if (hashBytes != (unsigned)crSpec->mac_size || ++ NSS_SecureMemcmp(givenHash, hash, crSpec->mac_size) != 0) { ++ /* We're allowed to leak whether or not the MAC check was correct */ ++ good = 0; ++ } + } + + if (good == 0) { +Index: net/third_party/nss/ssl/sslenum.c +=================================================================== +--- net/third_party/nss/ssl/sslenum.c (revision 217715) ++++ net/third_party/nss/ssl/sslenum.c (working copy) +@@ -29,6 +29,14 @@ + * Finally, update the ssl_V3_SUITES_IMPLEMENTED macro in sslimpl.h. + */ + const PRUint16 SSL_ImplementedCiphers[] = { ++ /* AES-GCM */ ++#ifdef NSS_ENABLE_ECC ++ TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, ++ TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, ++#endif /* NSS_ENABLE_ECC */ ++ TLS_DHE_RSA_WITH_AES_128_GCM_SHA256, ++ TLS_RSA_WITH_AES_128_GCM_SHA256, ++ + /* 256-bit */ + #ifdef NSS_ENABLE_ECC + TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, +Index: net/third_party/nss/ssl/sslproto.h +=================================================================== +--- net/third_party/nss/ssl/sslproto.h (revision 217715) ++++ net/third_party/nss/ssl/sslproto.h (working copy) +@@ -162,6 +162,10 @@ + + #define TLS_RSA_WITH_SEED_CBC_SHA 0x0096 + ++#define TLS_RSA_WITH_AES_128_GCM_SHA256 0x009C ++#define TLS_DHE_RSA_WITH_AES_128_GCM_SHA256 0x009E ++#define TLS_DHE_DSS_WITH_AES_128_GCM_SHA256 0x00A2 ++ + /* TLS "Signaling Cipher Suite Value" (SCSV). May be requested by client. + * Must NEVER be chosen by server. SSL 3.0 server acknowledges by sending + * back an empty Renegotiation Info (RI) server hello extension. +@@ -204,6 +208,11 @@ + #define TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256 0xC023 + #define TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256 0xC027 + ++#define TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 0xC02B ++#define TLS_ECDH_ECDSA_WITH_AES_128_GCM_SHA256 0xC02D ++#define TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 0xC02F ++#define TLS_ECDH_RSA_WITH_AES_128_GCM_SHA256 0xC031 ++ + /* Netscape "experimental" cipher suites. */ + #define SSL_RSA_OLDFIPS_WITH_3DES_EDE_CBC_SHA 0xffe0 + #define SSL_RSA_OLDFIPS_WITH_DES_CBC_SHA 0xffe1 +Index: net/third_party/nss/ssl/sslt.h +=================================================================== +--- net/third_party/nss/ssl/sslt.h (revision 217715) ++++ net/third_party/nss/ssl/sslt.h (working copy) +@@ -91,9 +91,10 @@ + ssl_calg_3des = 4, + ssl_calg_idea = 5, + ssl_calg_fortezza = 6, /* deprecated, now unused */ +- ssl_calg_aes = 7, /* coming soon */ ++ ssl_calg_aes = 7, + ssl_calg_camellia = 8, +- ssl_calg_seed = 9 ++ ssl_calg_seed = 9, ++ ssl_calg_aes_gcm = 10 + } SSLCipherAlgorithm; + + typedef enum { +@@ -102,7 +103,8 @@ + ssl_mac_sha = 2, + ssl_hmac_md5 = 3, /* TLS HMAC version of mac_md5 */ + ssl_hmac_sha = 4, /* TLS HMAC version of mac_sha */ +- ssl_hmac_sha256 = 5 ++ ssl_hmac_sha256 = 5, ++ ssl_mac_aead = 6 + } SSLMACAlgorithm; + + typedef enum { +@@ -158,6 +160,9 @@ + PRUint16 effectiveKeyBits; + + /* MAC info */ ++ /* AEAD ciphers don't have a MAC. For an AEAD cipher, macAlgorithmName ++ * is "AEAD", macAlgorithm is ssl_mac_aead, and macBits is the length in ++ * bits of the authentication tag. */ + const char * macAlgorithmName; + SSLMACAlgorithm macAlgorithm; + PRUint16 macBits; diff --git a/chromium/net/third_party/nss/patches/aesgcmchromium.patch b/chromium/net/third_party/nss/patches/aesgcmchromium.patch new file mode 100644 index 00000000000..f9ec6cb4217 --- /dev/null +++ b/chromium/net/third_party/nss/patches/aesgcmchromium.patch @@ -0,0 +1,117 @@ +--- net/third_party/nss/ssl/ssl3con.c.orig 2013-08-20 12:00:16.742760827 -0700 ++++ net/third_party/nss/ssl/ssl3con.c 2013-08-20 11:59:56.782463207 -0700 +@@ -44,6 +44,9 @@ + #ifdef NSS_ENABLE_ZLIB + #include "zlib.h" + #endif ++#ifdef LINUX ++#include <dlfcn.h> ++#endif + + #ifndef PK11_SETATTRS + #define PK11_SETATTRS(x,id,v,l) (x)->type = (id); \ +@@ -1819,6 +1822,69 @@ ssl3_BuildRecordPseudoHeader(unsigned ch + return 13; + } + ++typedef SECStatus (*PK11CryptFcn)( ++ PK11SymKey *symKey, CK_MECHANISM_TYPE mechanism, SECItem *param, ++ unsigned char *out, unsigned int *outLen, unsigned int maxLen, ++ const unsigned char *in, unsigned int inLen); ++ ++static PK11CryptFcn pk11_encrypt = NULL; ++static PK11CryptFcn pk11_decrypt = NULL; ++ ++static PRCallOnceType resolvePK11CryptOnce; ++ ++static PRStatus ++ssl3_ResolvePK11CryptFunctions(void) ++{ ++#ifdef LINUX ++ /* On Linux we use the system NSS libraries. Look up the PK11_Encrypt and ++ * PK11_Decrypt functions at run time. */ ++ void *handle = dlopen(NULL, RTLD_LAZY); ++ if (!handle) { ++ PORT_SetError(SEC_ERROR_LIBRARY_FAILURE); ++ return PR_FAILURE; ++ } ++ pk11_encrypt = (PK11CryptFcn)dlsym(handle, "PK11_Encrypt"); ++ pk11_decrypt = (PK11CryptFcn)dlsym(handle, "PK11_Decrypt"); ++ dlclose(handle); ++ return PR_SUCCESS; ++#else ++ /* On other platforms we use our own copy of NSS. PK11_Encrypt and ++ * PK11_Decrypt are known to be available. */ ++ pk11_encrypt = PK11_Encrypt; ++ pk11_decrypt = PK11_Decrypt; ++ return PR_SUCCESS; ++#endif ++} ++ ++/* ++ * In NSS 3.15, PK11_Encrypt and PK11_Decrypt were added to provide access ++ * to the AES GCM implementation in the NSS softoken. So the presence of ++ * these two functions implies the NSS version supports AES GCM. ++ */ ++static PRBool ++ssl3_HasGCMSupport(void) ++{ ++ (void)PR_CallOnce(&resolvePK11CryptOnce, ssl3_ResolvePK11CryptFunctions); ++ return pk11_encrypt != NULL; ++} ++ ++/* On this socket, disable the GCM cipher suites */ ++SECStatus ++ssl3_DisableGCMSuites(sslSocket * ss) ++{ ++ unsigned int i; ++ ++ for (i = 0; i < PR_ARRAY_SIZE(cipher_suite_defs); i++) { ++ const ssl3CipherSuiteDef *cipher_def = &cipher_suite_defs[i]; ++ if (cipher_def->bulk_cipher_alg == cipher_aes_128_gcm) { ++ SECStatus rv = ssl3_CipherPrefSet(ss, cipher_def->cipher_suite, ++ PR_FALSE); ++ PORT_Assert(rv == SECSuccess); /* else is coding error */ ++ } ++ } ++ return SECSuccess; ++} ++ + static SECStatus + ssl3_AESGCM(ssl3KeyMaterial *keys, + PRBool doDecrypt, +@@ -1870,10 +1936,10 @@ ssl3_AESGCM(ssl3KeyMaterial *keys, + gcmParams.ulTagBits = tagSize * 8; + + if (doDecrypt) { +- rv = PK11_Decrypt(keys->write_key, CKM_AES_GCM, ¶m, out, &uOutLen, ++ rv = pk11_decrypt(keys->write_key, CKM_AES_GCM, ¶m, out, &uOutLen, + maxout, in, inlen); + } else { +- rv = PK11_Encrypt(keys->write_key, CKM_AES_GCM, ¶m, out, &uOutLen, ++ rv = pk11_encrypt(keys->write_key, CKM_AES_GCM, ¶m, out, &uOutLen, + maxout, in, inlen); + } + *outlen += (int) uOutLen; +@@ -5023,6 +5089,10 @@ ssl3_SendClientHello(sslSocket *ss, PRBo + ssl3_DisableNonDTLSSuites(ss); + } + ++ if (!ssl3_HasGCMSupport()) { ++ ssl3_DisableGCMSuites(ss); ++ } ++ + /* how many suites are permitted by policy and user preference? */ + num_suites = count_cipher_suites(ss, ss->ssl3.policy, PR_TRUE); + if (!num_suites) +@@ -7728,6 +7798,10 @@ ssl3_HandleClientHello(sslSocket *ss, SS + ssl3_DisableNonDTLSSuites(ss); + } + ++ if (!ssl3_HasGCMSupport()) { ++ ssl3_DisableGCMSuites(ss); ++ } ++ + #ifdef PARANOID + /* Look for a matching cipher suite. */ + j = ssl3_config_match_init(ss); diff --git a/chromium/net/third_party/nss/patches/applypatches.sh b/chromium/net/third_party/nss/patches/applypatches.sh index 68798aa6ebb..947cf5e1b57 100755 --- a/chromium/net/third_party/nss/patches/applypatches.sh +++ b/chromium/net/third_party/nss/patches/applypatches.sh @@ -48,4 +48,20 @@ patch -p4 < $patches_dir/alpn.patch patch -p5 < $patches_dir/sslsock_903565.patch +patch -p4 < $patches_dir/aesgcm.patch + +patch -p4 < $patches_dir/aesgcmchromium.patch + patch -p4 < $patches_dir/tls12backuphash.patch + +patch -p4 < $patches_dir/chacha20poly1305.patch + +patch -p4 < $patches_dir/cachelocks.patch + +patch -p4 < $patches_dir/ciphersuiteversion.patch + +patch -p4 < $patches_dir/peercertchain2.patch + +patch -p4 < $patches_dir/canfalsestart.patch + +patch -p4 < $patches_dir/nullcipher_934016.patch diff --git a/chromium/net/third_party/nss/patches/cachelocks.patch b/chromium/net/third_party/nss/patches/cachelocks.patch new file mode 100644 index 00000000000..5b3f93ed822 --- /dev/null +++ b/chromium/net/third_party/nss/patches/cachelocks.patch @@ -0,0 +1,246 @@ +diff --git a/nss/lib/ssl/ssl3con.c b/nss/lib/ssl/ssl3con.c +index 53c29f0..bc54c99 100644 +--- a/nss/lib/ssl/ssl3con.c ++++ b/nss/lib/ssl/ssl3con.c +@@ -5593,7 +5593,6 @@ SSL3_ShutdownServerCache(void) + } + + PZ_Unlock(symWrapKeysLock); +- ssl_FreeSessionCacheLocks(); + return SECSuccess; + } + +@@ -5645,7 +5644,7 @@ getWrappingKey( sslSocket * ss, + + pSymWrapKey = &symWrapKeys[symWrapMechIndex].symWrapKey[exchKeyType]; + +- ssl_InitSessionCacheLocks(PR_TRUE); ++ ssl_InitSessionCacheLocks(); + + PZ_Lock(symWrapKeysLock); + +diff --git a/nss/lib/ssl/sslimpl.h b/nss/lib/ssl/sslimpl.h +index e3ae9ce..59140f8 100644 +--- a/nss/lib/ssl/sslimpl.h ++++ b/nss/lib/ssl/sslimpl.h +@@ -1845,9 +1845,7 @@ extern SECStatus ssl_InitSymWrapKeysLock(void); + + extern SECStatus ssl_FreeSymWrapKeysLock(void); + +-extern SECStatus ssl_InitSessionCacheLocks(PRBool lazyInit); +- +-extern SECStatus ssl_FreeSessionCacheLocks(void); ++extern SECStatus ssl_InitSessionCacheLocks(void); + + /***************** platform client auth ****************/ + +diff --git a/nss/lib/ssl/sslnonce.c b/nss/lib/ssl/sslnonce.c +index 5d8a954..a6f7349 100644 +--- a/nss/lib/ssl/sslnonce.c ++++ b/nss/lib/ssl/sslnonce.c +@@ -35,91 +35,55 @@ static PZLock * cacheLock = NULL; + #define LOCK_CACHE lock_cache() + #define UNLOCK_CACHE PZ_Unlock(cacheLock) + ++static PRCallOnceType lockOnce; ++ ++/* FreeSessionCacheLocks is a callback from NSS_RegisterShutdown which destroys ++ * the session cache locks on shutdown and resets them to their initial ++ * state. */ + static SECStatus +-ssl_InitClientSessionCacheLock(void) ++FreeSessionCacheLocks(void* appData, void* nssData) + { ++ static const PRCallOnceType pristineCallOnce; ++ SECStatus rv; ++ ++ if (!cacheLock) { ++ PORT_SetError(SEC_ERROR_NOT_INITIALIZED); ++ return SECFailure; ++ } ++ ++ PZ_DestroyLock(cacheLock); ++ cacheLock = NULL; ++ ++ rv = ssl_FreeSymWrapKeysLock(); ++ if (rv != SECSuccess) { ++ return rv; ++ } ++ ++ lockOnce = pristineCallOnce; ++ return SECSuccess; ++} ++ ++/* InitSessionCacheLocks is called, protected by lockOnce, to create the ++ * session cache locks. */ ++static PRStatus ++InitSessionCacheLocks(void) ++{ ++ SECStatus rv; ++ + cacheLock = PZ_NewLock(nssILockCache); +- return cacheLock ? SECSuccess : SECFailure; +-} +- +-static SECStatus +-ssl_FreeClientSessionCacheLock(void) +-{ +- if (cacheLock) { ++ if (cacheLock == NULL) { ++ return PR_FAILURE; ++ } ++ rv = ssl_InitSymWrapKeysLock(); ++ if (rv != SECSuccess) { ++ PRErrorCode error = PORT_GetError(); + PZ_DestroyLock(cacheLock); + cacheLock = NULL; +- return SECSuccess; +- } +- PORT_SetError(SEC_ERROR_NOT_INITIALIZED); +- return SECFailure; +-} +- +-static PRBool LocksInitializedEarly = PR_FALSE; +- +-static SECStatus +-FreeSessionCacheLocks() +-{ +- SECStatus rv1, rv2; +- rv1 = ssl_FreeSymWrapKeysLock(); +- rv2 = ssl_FreeClientSessionCacheLock(); +- if ( (SECSuccess == rv1) && (SECSuccess == rv2) ) { +- return SECSuccess; +- } +- return SECFailure; +-} +- +-static SECStatus +-InitSessionCacheLocks(void) +-{ +- SECStatus rv1, rv2; +- PRErrorCode rc; +- rv1 = ssl_InitSymWrapKeysLock(); +- rv2 = ssl_InitClientSessionCacheLock(); +- if ( (SECSuccess == rv1) && (SECSuccess == rv2) ) { +- return SECSuccess; +- } +- rc = PORT_GetError(); +- FreeSessionCacheLocks(); +- PORT_SetError(rc); +- return SECFailure; +-} +- +-/* free the session cache locks if they were initialized early */ +-SECStatus +-ssl_FreeSessionCacheLocks() +-{ +- PORT_Assert(PR_TRUE == LocksInitializedEarly); +- if (!LocksInitializedEarly) { +- PORT_SetError(SEC_ERROR_NOT_INITIALIZED); +- return SECFailure; +- } +- FreeSessionCacheLocks(); +- LocksInitializedEarly = PR_FALSE; +- return SECSuccess; +-} +- +-static PRCallOnceType lockOnce; +- +-/* free the session cache locks if they were initialized lazily */ +-static SECStatus ssl_ShutdownLocks(void* appData, void* nssData) +-{ +- PORT_Assert(PR_FALSE == LocksInitializedEarly); +- if (LocksInitializedEarly) { +- PORT_SetError(SEC_ERROR_LIBRARY_FAILURE); +- return SECFailure; +- } +- FreeSessionCacheLocks(); +- memset(&lockOnce, 0, sizeof(lockOnce)); +- return SECSuccess; +-} +- +-static PRStatus initSessionCacheLocksLazily(void) +-{ +- SECStatus rv = InitSessionCacheLocks(); +- if (SECSuccess != rv) { ++ PORT_SetError(error); + return PR_FAILURE; + } +- rv = NSS_RegisterShutdown(ssl_ShutdownLocks, NULL); ++ ++ rv = NSS_RegisterShutdown(FreeSessionCacheLocks, NULL); + PORT_Assert(SECSuccess == rv); + if (SECSuccess != rv) { + return PR_FAILURE; +@@ -127,34 +91,18 @@ static PRStatus initSessionCacheLocksLazily(void) + return PR_SUCCESS; + } + +-/* lazyInit means that the call is not happening during a 1-time +- * initialization function, but rather during dynamic, lazy initialization +- */ + SECStatus +-ssl_InitSessionCacheLocks(PRBool lazyInit) ++ssl_InitSessionCacheLocks(void) + { +- if (LocksInitializedEarly) { +- return SECSuccess; +- } +- +- if (lazyInit) { +- return (PR_SUCCESS == +- PR_CallOnce(&lockOnce, initSessionCacheLocksLazily)) ? +- SECSuccess : SECFailure; +- } +- +- if (SECSuccess == InitSessionCacheLocks()) { +- LocksInitializedEarly = PR_TRUE; +- return SECSuccess; +- } +- +- return SECFailure; ++ return (PR_SUCCESS == ++ PR_CallOnce(&lockOnce, InitSessionCacheLocks)) ? ++ SECSuccess : SECFailure; + } + +-static void ++static void + lock_cache(void) + { +- ssl_InitSessionCacheLocks(PR_TRUE); ++ ssl_InitSessionCacheLocks(); + PZ_Lock(cacheLock); + } + +diff --git a/nss/lib/ssl/sslsnce.c b/nss/lib/ssl/sslsnce.c +index b0446ad..34e07b0 100644 +--- a/nss/lib/ssl/sslsnce.c ++++ b/nss/lib/ssl/sslsnce.c +@@ -1353,7 +1353,7 @@ SSL_ConfigServerSessionIDCache( int maxCacheEntries, + PRUint32 ssl3_timeout, + const char * directory) + { +- ssl_InitSessionCacheLocks(PR_FALSE); ++ ssl_InitSessionCacheLocks(); + return SSL_ConfigServerSessionIDCacheInstance(&globalCache, + maxCacheEntries, ssl2_timeout, ssl3_timeout, directory, PR_FALSE); + } +@@ -1467,7 +1467,7 @@ SSL_ConfigServerSessionIDCacheWithOpt( + PRBool enableMPCache) + { + if (!enableMPCache) { +- ssl_InitSessionCacheLocks(PR_FALSE); ++ ssl_InitSessionCacheLocks(); + return ssl_ConfigServerSessionIDCacheInstanceWithOpt(&globalCache, + ssl2_timeout, ssl3_timeout, directory, PR_FALSE, + maxCacheEntries, maxCertCacheEntries, maxSrvNameCacheEntries); +@@ -1512,7 +1512,7 @@ SSL_InheritMPServerSIDCacheInstance(cacheDesc *cache, const char * envString) + return SECSuccess; /* already done. */ + } + +- ssl_InitSessionCacheLocks(PR_FALSE); ++ ssl_InitSessionCacheLocks(); + + ssl_sid_lookup = ServerSessionIDLookup; + ssl_sid_cache = ServerSessionIDCache; diff --git a/chromium/net/third_party/nss/patches/canfalsestart.patch b/chromium/net/third_party/nss/patches/canfalsestart.patch new file mode 100644 index 00000000000..d2a9752c070 --- /dev/null +++ b/chromium/net/third_party/nss/patches/canfalsestart.patch @@ -0,0 +1,637 @@ +Index: net/third_party/nss/ssl/ssl.h +=================================================================== +--- net/third_party/nss/ssl/ssl.h (revision 227363) ++++ net/third_party/nss/ssl/ssl.h (working copy) +@@ -121,14 +121,22 @@ + #define SSL_ENABLE_FALSE_START 22 /* Enable SSL false start (off by */ + /* default, applies only to */ + /* clients). False start is a */ +-/* mode where an SSL client will start sending application data before */ +-/* verifying the server's Finished message. This means that we could end up */ +-/* sending data to an imposter. However, the data will be encrypted and */ +-/* only the true server can derive the session key. Thus, so long as the */ +-/* cipher isn't broken this is safe. Because of this, False Start will only */ +-/* occur on RSA or DH ciphersuites where the cipher's key length is >= 80 */ +-/* bits. The advantage of False Start is that it saves a round trip for */ +-/* client-speaks-first protocols when performing a full handshake. */ ++/* mode where an SSL client will start sending application data before ++ * verifying the server's Finished message. This means that we could end up ++ * sending data to an imposter. However, the data will be encrypted and ++ * only the true server can derive the session key. Thus, so long as the ++ * cipher isn't broken this is safe. The advantage of false start is that ++ * it saves a round trip for client-speaks-first protocols when performing a ++ * full handshake. ++ * ++ * See SSL_DefaultCanFalseStart for the default criteria that NSS uses to ++ * determine whether to false start or not. See SSL_SetCanFalseStartCallback ++ * for how to change that criteria. In addition to those criteria, false start ++ * will only be done when the server selects a cipher suite with an effective ++ * key length of 80 bits or more (including RC4-128). Also, see ++ * SSL_HandshakeCallback for a description on how false start affects when the ++ * handshake callback gets called. ++ */ + + /* For SSL 3.0 and TLS 1.0, by default we prevent chosen plaintext attacks + * on SSL CBC mode cipher suites (see RFC 4346 Section F.3) by splitting +@@ -741,14 +749,59 @@ + SSL_IMPORT SECStatus SSL_InheritMPServerSIDCache(const char * envString); + + /* +-** Set the callback on a particular socket that gets called when we finish +-** performing a handshake. ++** Set the callback that normally gets called when the TLS handshake ++** is complete. If false start is not enabled, then the handshake callback is ++** called after verifying the peer's Finished message and before sending ++** outgoing application data and before processing incoming application data. ++** ++** If false start is enabled and there is a custom CanFalseStartCallback ++** callback set, then the handshake callback gets called after the peer's ++** Finished message has been verified, which may be after application data is ++** sent. ++** ++** If false start is enabled and there is not a custom CanFalseStartCallback ++** callback established with SSL_SetCanFalseStartCallback then the handshake ++** callback gets called before any application data is sent, which may be ++** before the peer's Finished message has been verified. + */ + typedef void (PR_CALLBACK *SSLHandshakeCallback)(PRFileDesc *fd, + void *client_data); + SSL_IMPORT SECStatus SSL_HandshakeCallback(PRFileDesc *fd, + SSLHandshakeCallback cb, void *client_data); + ++/* Applications that wish to customize TLS false start should set this callback ++** function. NSS will invoke the functon to determine if a particular ++** connection should use false start or not. SECSuccess indicates that the ++** callback completed successfully, and if so *canFalseStart indicates if false ++** start can be used. If the callback does not return SECSuccess then the ++** handshake will be canceled. ++** ++** Applications that do not set the callback will use an internal set of ++** criteria to determine if the connection should false start. If ++** the callback is set false start will never be used without invoking the ++** callback function, but some connections (e.g. resumed connections) will ++** never use false start and therefore will not invoke the callback. ++** ++** NSS's internal criteria for this connection can be evaluated by calling ++** SSL_DefaultCanFalseStart() from the custom callback. ++** ++** See the description of SSL_HandshakeCallback for important information on ++** how registering a custom false start callback affects when the handshake ++** callback gets called. ++**/ ++typedef SECStatus (PR_CALLBACK *SSLCanFalseStartCallback)( ++ PRFileDesc *fd, void *arg, PRBool *canFalseStart); ++ ++SSL_IMPORT SECStatus SSL_SetCanFalseStartCallback( ++ PRFileDesc *fd, SSLCanFalseStartCallback callback, void *arg); ++ ++/* A utility function that can be called from a custom CanFalseStartCallback ++** function to determine what NSS would have done for this connection if the ++** custom callback was not implemented. ++**/ ++SSL_IMPORT SECStatus SSL_DefaultCanFalseStart(PRFileDesc *fd, ++ PRBool *canFalseStart); ++ + /* + ** For the server, request a new handshake. For the client, begin a new + ** handshake. If flushCache is non-zero, the SSL3 cache entry will be +Index: net/third_party/nss/ssl/ssl3gthr.c +=================================================================== +--- net/third_party/nss/ssl/ssl3gthr.c (revision 227363) ++++ net/third_party/nss/ssl/ssl3gthr.c (working copy) +@@ -374,9 +374,7 @@ + */ + if (ss->opt.enableFalseStart) { + ssl_GetSSL3HandshakeLock(ss); +- canFalseStart = (ss->ssl3.hs.ws == wait_change_cipher || +- ss->ssl3.hs.ws == wait_new_session_ticket) && +- ssl3_CanFalseStart(ss); ++ canFalseStart = ss->ssl3.hs.canFalseStart; + ssl_ReleaseSSL3HandshakeLock(ss); + } + } while (ss->ssl3.hs.ws != idle_handshake && +Index: net/third_party/nss/ssl/sslinfo.c +=================================================================== +--- net/third_party/nss/ssl/sslinfo.c (revision 227363) ++++ net/third_party/nss/ssl/sslinfo.c (working copy) +@@ -26,7 +26,6 @@ + sslSocket * ss; + SSLChannelInfo inf; + sslSessionID * sid; +- PRBool enoughFirstHsDone = PR_FALSE; + + if (!info || len < sizeof inf.length) { + PORT_SetError(SEC_ERROR_INVALID_ARGS); +@@ -43,14 +42,7 @@ + memset(&inf, 0, sizeof inf); + inf.length = PR_MIN(sizeof inf, len); + +- if (ss->firstHsDone) { +- enoughFirstHsDone = PR_TRUE; +- } else if (ss->version >= SSL_LIBRARY_VERSION_3_0 && +- ssl3_CanFalseStart(ss)) { +- enoughFirstHsDone = PR_TRUE; +- } +- +- if (ss->opt.useSecurity && enoughFirstHsDone) { ++ if (ss->opt.useSecurity && ss->enoughFirstHsDone) { + sid = ss->sec.ci.sid; + inf.protocolVersion = ss->version; + inf.authKeyBits = ss->sec.authKeyBits; +Index: net/third_party/nss/ssl/sslauth.c +=================================================================== +--- net/third_party/nss/ssl/sslauth.c (revision 227363) ++++ net/third_party/nss/ssl/sslauth.c (working copy) +@@ -100,7 +100,6 @@ + sslSocket *ss; + const char *cipherName; + PRBool isDes = PR_FALSE; +- PRBool enoughFirstHsDone = PR_FALSE; + + ss = ssl_FindSocket(fd); + if (!ss) { +@@ -118,14 +117,7 @@ + *op = SSL_SECURITY_STATUS_OFF; + } + +- if (ss->firstHsDone) { +- enoughFirstHsDone = PR_TRUE; +- } else if (ss->version >= SSL_LIBRARY_VERSION_3_0 && +- ssl3_CanFalseStart(ss)) { +- enoughFirstHsDone = PR_TRUE; +- } +- +- if (ss->opt.useSecurity && enoughFirstHsDone) { ++ if (ss->opt.useSecurity && ss->enoughFirstHsDone) { + if (ss->version < SSL_LIBRARY_VERSION_3_0) { + cipherName = ssl_cipherName[ss->sec.cipherType]; + } else { +Index: net/third_party/nss/ssl/sslimpl.h +=================================================================== +--- net/third_party/nss/ssl/sslimpl.h (revision 227363) ++++ net/third_party/nss/ssl/sslimpl.h (working copy) +@@ -881,6 +881,8 @@ + /* Shared state between ssl3_HandleFinished and ssl3_FinishHandshake */ + PRBool cacheSID; + ++ PRBool canFalseStart; /* Can/did we False Start */ ++ + /* clientSigAndHash contains the contents of the signature_algorithms + * extension (if any) from the client. This is only valid for TLS 1.2 + * or later. */ +@@ -1162,6 +1164,10 @@ + unsigned long clientAuthRequested; + unsigned long delayDisabled; /* Nagle delay disabled */ + unsigned long firstHsDone; /* first handshake is complete. */ ++ unsigned long enoughFirstHsDone; /* enough of the first handshake is ++ * done for callbacks to be able to ++ * retrieve channel security ++ * parameters from the SSL socket. */ + unsigned long handshakeBegun; + unsigned long lastWriteBlocked; + unsigned long recvdCloseNotify; /* received SSL EOF. */ +@@ -1210,6 +1216,8 @@ + void *badCertArg; + SSLHandshakeCallback handshakeCallback; + void *handshakeCallbackData; ++ SSLCanFalseStartCallback canFalseStartCallback; ++ void *canFalseStartCallbackData; + void *pkcs11PinArg; + SSLNextProtoCallback nextProtoCallback; + void *nextProtoArg; +@@ -1423,7 +1431,6 @@ + + extern SECStatus ssl_EnableNagleDelay(sslSocket *ss, PRBool enabled); + +-extern PRBool ssl3_CanFalseStart(sslSocket *ss); + extern SECStatus + ssl3_CompressMACEncryptRecord(ssl3CipherSpec * cwSpec, + PRBool isServer, +Index: net/third_party/nss/ssl/sslsecur.c +=================================================================== +--- net/third_party/nss/ssl/sslsecur.c (revision 227363) ++++ net/third_party/nss/ssl/sslsecur.c (working copy) +@@ -99,21 +99,12 @@ + if (ss->handshake == 0) { + ssl_GetRecvBufLock(ss); + ss->gs.recordLen = 0; ++ ss->gs.writeOffset = 0; ++ ss->gs.readOffset = 0; + ssl_ReleaseRecvBufLock(ss); + + SSL_TRC(3, ("%d: SSL[%d]: handshake is completed", + SSL_GETPID(), ss->fd)); +- /* call handshake callback for ssl v2 */ +- /* for v3 this is done in ssl3_HandleFinished() */ +- if ((ss->handshakeCallback != NULL) && /* has callback */ +- (!ss->firstHsDone) && /* only first time */ +- (ss->version < SSL_LIBRARY_VERSION_3_0)) { /* not ssl3 */ +- ss->firstHsDone = PR_TRUE; +- (ss->handshakeCallback)(ss->fd, ss->handshakeCallbackData); +- } +- ss->firstHsDone = PR_TRUE; +- ss->gs.writeOffset = 0; +- ss->gs.readOffset = 0; + break; + } + rv = (*ss->handshake)(ss); +@@ -206,6 +197,7 @@ + ssl_Get1stHandshakeLock(ss); + + ss->firstHsDone = PR_FALSE; ++ ss->enoughFirstHsDone = PR_FALSE; + if ( asServer ) { + ss->handshake = ssl2_BeginServerHandshake; + ss->handshaking = sslHandshakingAsServer; +@@ -221,6 +213,8 @@ + ssl_ReleaseRecvBufLock(ss); + + ssl_GetSSL3HandshakeLock(ss); ++ ss->ssl3.hs.canFalseStart = PR_FALSE; ++ ss->ssl3.hs.restartTarget = NULL; + + /* + ** Blow away old security state and get a fresh setup. +@@ -266,7 +260,7 @@ + + /* SSL v2 protocol does not support subsequent handshakes. */ + if (ss->version < SSL_LIBRARY_VERSION_3_0) { +- PORT_SetError(SEC_ERROR_INVALID_ARGS); ++ PORT_SetError(SSL_ERROR_FEATURE_NOT_SUPPORTED_FOR_SSL2); + rv = SECFailure; + } else { + ssl_GetSSL3HandshakeLock(ss); +@@ -331,6 +325,75 @@ + return SECSuccess; + } + ++/* Register an application callback to be called when false start may happen. ++** Acquires and releases HandshakeLock. ++*/ ++SECStatus ++SSL_SetCanFalseStartCallback(PRFileDesc *fd, SSLCanFalseStartCallback cb, ++ void *client_data) ++{ ++ sslSocket *ss; ++ ++ ss = ssl_FindSocket(fd); ++ if (!ss) { ++ SSL_DBG(("%d: SSL[%d]: bad socket in SSL_SetCanFalseStartCallback", ++ SSL_GETPID(), fd)); ++ return SECFailure; ++ } ++ ++ if (!ss->opt.useSecurity) { ++ PORT_SetError(SEC_ERROR_INVALID_ARGS); ++ return SECFailure; ++ } ++ ++ ssl_Get1stHandshakeLock(ss); ++ ssl_GetSSL3HandshakeLock(ss); ++ ++ ss->canFalseStartCallback = cb; ++ ss->canFalseStartCallbackData = client_data; ++ ++ ssl_ReleaseSSL3HandshakeLock(ss); ++ ssl_Release1stHandshakeLock(ss); ++ ++ return SECSuccess; ++} ++ ++/* A utility function that can be called from a custom SSLCanFalseStartCallback ++** function to determine what NSS would have done for this connection if the ++** custom callback was not implemented. ++*/ ++SECStatus ++SSL_DefaultCanFalseStart(PRFileDesc *fd, PRBool *canFalseStart) ++{ ++ sslSocket *ss; ++ ++ *canFalseStart = PR_FALSE; ++ ss = ssl_FindSocket(fd); ++ if (!ss) { ++ SSL_DBG(("%d: SSL[%d]: bad socket in SSL_DefaultCanFalseStart", ++ SSL_GETPID(), fd)); ++ return SECFailure; ++ } ++ ++ if (!ss->ssl3.initialized) { ++ PORT_SetError(SEC_ERROR_INVALID_ARGS); ++ return SECFailure; ++ } ++ ++ if (ss->version < SSL_LIBRARY_VERSION_3_0) { ++ PORT_SetError(SSL_ERROR_FEATURE_NOT_SUPPORTED_FOR_SSL2); ++ return SECFailure; ++ } ++ ++ /* Require a forward-secret key exchange. */ ++ *canFalseStart = ss->ssl3.hs.kea_def->kea == kea_dhe_dss || ++ ss->ssl3.hs.kea_def->kea == kea_dhe_rsa || ++ ss->ssl3.hs.kea_def->kea == kea_ecdhe_ecdsa || ++ ss->ssl3.hs.kea_def->kea == kea_ecdhe_rsa; ++ ++ return SECSuccess; ++} ++ + /* Try to make progress on an SSL handshake by attempting to read the + ** next handshake from the peer, and sending any responses. + ** For non-blocking sockets, returns PR_ERROR_WOULD_BLOCK if it cannot +@@ -1195,12 +1258,7 @@ + ssl_Get1stHandshakeLock(ss); + if (ss->version >= SSL_LIBRARY_VERSION_3_0) { + ssl_GetSSL3HandshakeLock(ss); +- if ((ss->ssl3.hs.ws == wait_change_cipher || +- ss->ssl3.hs.ws == wait_finished || +- ss->ssl3.hs.ws == wait_new_session_ticket) && +- ssl3_CanFalseStart(ss)) { +- canFalseStart = PR_TRUE; +- } ++ canFalseStart = ss->ssl3.hs.canFalseStart; + ssl_ReleaseSSL3HandshakeLock(ss); + } + if (!canFalseStart && +Index: net/third_party/nss/ssl/sslsock.c +=================================================================== +--- net/third_party/nss/ssl/sslsock.c (revision 227363) ++++ net/third_party/nss/ssl/sslsock.c (working copy) +@@ -2457,10 +2457,14 @@ + } else if (new_flags & PR_POLL_WRITE) { + /* The caller is trying to write, but the handshake is + ** blocked waiting for data to read, and the first +- ** handshake has been sent. so do NOT to poll on write. ++ ** handshake has been sent. So do NOT to poll on write ++ ** unless we did false start. + */ +- new_flags ^= PR_POLL_WRITE; /* don't select on write. */ +- new_flags |= PR_POLL_READ; /* do select on read. */ ++ if (!(ss->version >= SSL_LIBRARY_VERSION_3_0 && ++ ss->ssl3.hs.canFalseStart)) { ++ new_flags ^= PR_POLL_WRITE; /* don't select on write. */ ++ } ++ new_flags |= PR_POLL_READ; /* do select on read. */ + } + } + } else if ((new_flags & PR_POLL_READ) && (SSL_DataPending(fd) > 0)) { +Index: net/third_party/nss/ssl/ssl3con.c +=================================================================== +--- net/third_party/nss/ssl/ssl3con.c (revision 227363) ++++ net/third_party/nss/ssl/ssl3con.c (working copy) +@@ -2890,7 +2890,7 @@ + SSL_TRC(3, ("%d: SSL3[%d] SendRecord type: %s nIn=%d", + SSL_GETPID(), ss->fd, ssl3_DecodeContentType(type), + nIn)); +- PRINT_BUF(3, (ss, "Send record (plain text)", pIn, nIn)); ++ PRINT_BUF(50, (ss, "Send record (plain text)", pIn, nIn)); + + PORT_Assert( ss->opt.noLocks || ssl_HaveXmitBufLock(ss) ); + +@@ -7344,35 +7344,42 @@ + return rv; + } + +-PRBool +-ssl3_CanFalseStart(sslSocket *ss) { +- PRBool rv; ++static SECStatus ++ssl3_CheckFalseStart(sslSocket *ss) ++{ ++ SECStatus rv; ++ PRBool maybeFalseStart = PR_TRUE; + + PORT_Assert( ss->opt.noLocks || ssl_HaveSSL3HandshakeLock(ss) ); ++ PORT_Assert( !ss->ssl3.hs.authCertificatePending ); + +- /* XXX: does not take into account whether we are waiting for +- * SSL_AuthCertificateComplete or SSL_RestartHandshakeAfterCertReq. If/when +- * that is done, this function could return different results each time it +- * would be called. +- */ ++ /* An attacker can control the selected ciphersuite so we only wish to ++ * do False Start in the case that the selected ciphersuite is ++ * sufficiently strong that the attack can gain no advantage. ++ * Therefore we always require an 80-bit cipher. */ + + ssl_GetSpecReadLock(ss); +- rv = ss->opt.enableFalseStart && +- !ss->sec.isServer && +- !ss->ssl3.hs.isResuming && +- ss->ssl3.cwSpec && ++ if (ss->ssl3.cwSpec->cipher_def->secret_key_size < 10) { ++ ss->ssl3.hs.canFalseStart = PR_FALSE; ++ maybeFalseStart = PR_FALSE; ++ } ++ ssl_ReleaseSpecReadLock(ss); ++ if (!maybeFalseStart) { ++ return SECSuccess; ++ } + +- /* An attacker can control the selected ciphersuite so we only wish to +- * do False Start in the case that the selected ciphersuite is +- * sufficiently strong that the attack can gain no advantage. +- * Therefore we require an 80-bit cipher and a forward-secret key +- * exchange. */ +- ss->ssl3.cwSpec->cipher_def->secret_key_size >= 10 && +- (ss->ssl3.hs.kea_def->kea == kea_dhe_dss || +- ss->ssl3.hs.kea_def->kea == kea_dhe_rsa || +- ss->ssl3.hs.kea_def->kea == kea_ecdhe_ecdsa || +- ss->ssl3.hs.kea_def->kea == kea_ecdhe_rsa); +- ssl_ReleaseSpecReadLock(ss); ++ if (!ss->canFalseStartCallback) { ++ rv = SSL_DefaultCanFalseStart(ss->fd, &ss->ssl3.hs.canFalseStart); ++ } else { ++ rv = (ss->canFalseStartCallback)(ss->fd, ++ ss->canFalseStartCallbackData, ++ &ss->ssl3.hs.canFalseStart); ++ } ++ ++ if (rv != SECSuccess) { ++ ss->ssl3.hs.canFalseStart = PR_FALSE; ++ } ++ + return rv; + } + +@@ -7500,20 +7507,59 @@ + goto loser; /* err code was set. */ + } + +- /* XXX: If the server's certificate hasn't been authenticated by this +- * point, then we may be leaking this NPN message to an attacker. ++ /* This must be done after we've set ss->ssl3.cwSpec in ++ * ssl3_SendChangeCipherSpecs because SSL_GetChannelInfo uses information ++ * from cwSpec. This must be done before we call ssl3_CheckFalseStart ++ * because the false start callback (if any) may need the information from ++ * the functions that depend on this being set. + */ ++ ss->enoughFirstHsDone = PR_TRUE; ++ + if (!ss->firstHsDone) { ++ /* XXX: If the server's certificate hasn't been authenticated by this ++ * point, then we may be leaking this NPN message to an attacker. ++ */ + rv = ssl3_SendNextProto(ss); + if (rv != SECSuccess) { + goto loser; /* err code was set. */ + } + } ++ + rv = ssl3_SendEncryptedExtensions(ss); + if (rv != SECSuccess) { + goto loser; /* err code was set. */ + } + ++ if (!ss->firstHsDone) { ++ if (ss->opt.enableFalseStart) { ++ if (!ss->ssl3.hs.authCertificatePending) { ++ /* When we fix bug 589047, we will need to know whether we are ++ * false starting before we try to flush the client second ++ * round to the network. With that in mind, we purposefully ++ * call ssl3_CheckFalseStart before calling ssl3_SendFinished, ++ * which includes a call to ssl3_FlushHandshake, so that ++ * no application develops a reliance on such flushing being ++ * done before its false start callback is called. ++ */ ++ ssl_ReleaseXmitBufLock(ss); ++ rv = ssl3_CheckFalseStart(ss); ++ ssl_GetXmitBufLock(ss); ++ if (rv != SECSuccess) { ++ goto loser; ++ } ++ } else { ++ /* The certificate authentication and the server's Finished ++ * message are racing each other. If the certificate ++ * authentication wins, then we will try to false start in ++ * ssl3_AuthCertificateComplete. ++ */ ++ SSL_TRC(3, ("%d: SSL3[%p]: deferring false start check because" ++ " certificate authentication is still pending.", ++ SSL_GETPID(), ss->fd)); ++ } ++ } ++ } ++ + rv = ssl3_SendFinished(ss, 0); + if (rv != SECSuccess) { + goto loser; /* err code was set. */ +@@ -7526,8 +7572,16 @@ + else + ss->ssl3.hs.ws = wait_change_cipher; + +- /* Do the handshake callback for sslv3 here, if we can false start. */ +- if (ss->handshakeCallback != NULL && ssl3_CanFalseStart(ss)) { ++ if (ss->handshakeCallback && ++ (ss->ssl3.hs.canFalseStart && !ss->canFalseStartCallback)) { ++ /* Call the handshake callback here for backwards compatibility with ++ * applications that were using false start before ++ * canFalseStartCallback was added. Note that we do this after calling ++ * ssl3_SendFinished, which includes a call to ssl3_FlushHandshake, ++ * just in case the application is relying on having the handshake ++ * messages flushed to the network before its handshake callback is ++ * called. ++ */ + (ss->handshakeCallback)(ss->fd, ss->handshakeCallbackData); + } + +@@ -10147,13 +10201,6 @@ + + ss->ssl3.hs.authCertificatePending = PR_TRUE; + rv = SECSuccess; +- +- /* XXX: Async cert validation and False Start don't work together +- * safely yet; if we leave False Start enabled, we may end up false +- * starting (sending application data) before we +- * SSL_AuthCertificateComplete has been called. +- */ +- ss->opt.enableFalseStart = PR_FALSE; + } + + if (rv != SECSuccess) { +@@ -10278,6 +10325,12 @@ + } else if (ss->ssl3.hs.restartTarget != NULL) { + sslRestartTarget target = ss->ssl3.hs.restartTarget; + ss->ssl3.hs.restartTarget = NULL; ++ ++ if (target == ssl3_FinishHandshake) { ++ SSL_TRC(3,("%d: SSL3[%p]: certificate authentication lost the race" ++ " with peer's finished message", SSL_GETPID(), ss->fd)); ++ } ++ + rv = target(ss); + /* Even if we blocked here, we have accomplished enough to claim + * success. Any remaining work will be taken care of by subsequent +@@ -10287,7 +10340,39 @@ + rv = SECSuccess; + } + } else { +- rv = SECSuccess; ++ SSL_TRC(3, ("%d: SSL3[%p]: certificate authentication won the race" ++ " with peer's finished message", SSL_GETPID(), ss->fd)); ++ ++ PORT_Assert(!ss->firstHsDone); ++ PORT_Assert(!ss->sec.isServer); ++ PORT_Assert(!ss->ssl3.hs.isResuming); ++ PORT_Assert(ss->ssl3.hs.ws == wait_change_cipher || ++ ss->ssl3.hs.ws == wait_finished || ++ ss->ssl3.hs.ws == wait_new_session_ticket); ++ ++ /* ssl3_SendClientSecondRound deferred the false start check because ++ * certificate authentication was pending, so we have to do it now. ++ */ ++ if (ss->opt.enableFalseStart && ++ !ss->firstHsDone && ++ !ss->sec.isServer && ++ !ss->ssl3.hs.isResuming && ++ (ss->ssl3.hs.ws == wait_change_cipher || ++ ss->ssl3.hs.ws == wait_finished || ++ ss->ssl3.hs.ws == wait_new_session_ticket)) { ++ rv = ssl3_CheckFalseStart(ss); ++ if (rv == SECSuccess && ++ ss->handshakeCallback && ++ (ss->ssl3.hs.canFalseStart && !ss->canFalseStartCallback)) { ++ /* Call the handshake callback here for backwards compatibility ++ * with applications that were using false start before ++ * canFalseStartCallback was added. ++ */ ++ (ss->handshakeCallback)(ss->fd, ss->handshakeCallbackData); ++ } ++ } else { ++ rv = SECSuccess; ++ } + } + + done: +@@ -10983,6 +11068,8 @@ + SECStatus + ssl3_FinishHandshake(sslSocket * ss) + { ++ PRBool falseStarted; ++ + PORT_Assert( ss->opt.noLocks || ssl_HaveRecvBufLock(ss) ); + PORT_Assert( ss->opt.noLocks || ssl_HaveSSL3HandshakeLock(ss) ); + PORT_Assert( ss->ssl3.hs.restartTarget == NULL ); +@@ -10990,6 +11077,7 @@ + /* The first handshake is now completed. */ + ss->handshake = NULL; + ss->firstHsDone = PR_TRUE; ++ ss->enoughFirstHsDone = PR_TRUE; + + if (ss->ssl3.hs.cacheSID) { + (*ss->sec.cache)(ss->sec.ci.sid); +@@ -10997,9 +11085,14 @@ + } + + ss->ssl3.hs.ws = idle_handshake; ++ falseStarted = ss->ssl3.hs.canFalseStart; ++ ss->ssl3.hs.canFalseStart = PR_FALSE; /* False Start phase is complete */ + +- /* Do the handshake callback for sslv3 here, if we cannot false start. */ +- if (ss->handshakeCallback != NULL && !ssl3_CanFalseStart(ss)) { ++ /* Call the handshake callback for sslv3 here, unless we called it already ++ * for the case where false start was done without a canFalseStartCallback. ++ */ ++ if (ss->handshakeCallback && ++ !(falseStarted && !ss->canFalseStartCallback)) { + (ss->handshakeCallback)(ss->fd, ss->handshakeCallbackData); + } + diff --git a/chromium/net/third_party/nss/patches/chacha20poly1305.patch b/chromium/net/third_party/nss/patches/chacha20poly1305.patch new file mode 100644 index 00000000000..c858413f3c3 --- /dev/null +++ b/chromium/net/third_party/nss/patches/chacha20poly1305.patch @@ -0,0 +1,280 @@ +diff --git a/nss/lib/ssl/ssl3con.c b/nss/lib/ssl/ssl3con.c +index 8be517c..53c29f0 100644 +--- a/nss/lib/ssl/ssl3con.c ++++ b/nss/lib/ssl/ssl3con.c +@@ -40,6 +40,21 @@ + #define CKM_NSS_TLS_MASTER_KEY_DERIVE_DH_SHA256 (CKM_NSS + 24) + #endif + ++/* This is a bodge to allow this code to be compiled against older NSS ++ * headers. */ ++#ifndef CKM_NSS_CHACHA20_POLY1305 ++#define CKM_NSS_CHACHA20_POLY1305 (CKM_NSS + 25) ++ ++typedef struct CK_AEAD_PARAMS { ++ CK_BYTE_PTR pIv; /* This is the nonce. */ ++ CK_ULONG ulIvLen; ++ CK_BYTE_PTR pAAD; ++ CK_ULONG ulAADLen; ++ CK_ULONG ulTagBits; ++} CK_AEAD_PARAMS; ++ ++#endif ++ + #include <stdio.h> + #ifdef NSS_ENABLE_ZLIB + #include "zlib.h" +@@ -100,6 +115,8 @@ static SECStatus ssl3_AESGCMBypass(ssl3KeyMaterial *keys, PRBool doDecrypt, + static ssl3CipherSuiteCfg cipherSuites[ssl_V3_SUITES_IMPLEMENTED] = { + /* cipher_suite policy enabled is_present*/ + #ifdef NSS_ENABLE_ECC ++ { TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, SSL_NOT_ALLOWED, PR_FALSE,PR_FALSE}, ++ { TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, SSL_NOT_ALLOWED, PR_FALSE,PR_FALSE}, + { TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,SSL_NOT_ALLOWED, PR_FALSE,PR_FALSE}, + { TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, SSL_NOT_ALLOWED, PR_FALSE,PR_FALSE}, + #endif /* NSS_ENABLE_ECC */ +@@ -273,6 +290,7 @@ static const ssl3BulkCipherDef bulk_cipher_defs[] = { + {cipher_camellia_256, calg_camellia, 32,32, type_block, 16,16, 0, 0}, + {cipher_seed, calg_seed, 16,16, type_block, 16,16, 0, 0}, + {cipher_aes_128_gcm, calg_aes_gcm, 16,16, type_aead, 4, 0,16, 8}, ++ {cipher_chacha20, calg_chacha20, 32,32, type_aead, 0, 0,16, 0}, + {cipher_missing, calg_null, 0, 0, type_stream, 0, 0, 0, 0}, + }; + +@@ -399,6 +417,8 @@ static const ssl3CipherSuiteDef cipher_suite_defs[] = + {TLS_RSA_WITH_AES_128_GCM_SHA256, cipher_aes_128_gcm, mac_aead, kea_rsa}, + {TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, cipher_aes_128_gcm, mac_aead, kea_ecdhe_rsa}, + {TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, cipher_aes_128_gcm, mac_aead, kea_ecdhe_ecdsa}, ++ {TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, cipher_chacha20, mac_aead, kea_ecdhe_rsa}, ++ {TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, cipher_chacha20, mac_aead, kea_ecdhe_ecdsa}, + + #ifdef NSS_ENABLE_ECC + {TLS_ECDH_ECDSA_WITH_NULL_SHA, cipher_null, mac_sha, kea_ecdh_ecdsa}, +@@ -464,6 +484,7 @@ static const SSLCipher2Mech alg2Mech[] = { + { calg_camellia , CKM_CAMELLIA_CBC }, + { calg_seed , CKM_SEED_CBC }, + { calg_aes_gcm , CKM_AES_GCM }, ++ { calg_chacha20 , CKM_NSS_CHACHA20_POLY1305 }, + /* { calg_init , (CK_MECHANISM_TYPE)0x7fffffffL } */ + }; + +@@ -2020,6 +2041,46 @@ ssl3_AESGCMBypass(ssl3KeyMaterial *keys, + } + #endif + ++static SECStatus ++ssl3_ChaCha20Poly1305( ++ ssl3KeyMaterial *keys, ++ PRBool doDecrypt, ++ unsigned char *out, ++ int *outlen, ++ int maxout, ++ const unsigned char *in, ++ int inlen, ++ const unsigned char *additionalData, ++ int additionalDataLen) ++{ ++ SECItem param; ++ SECStatus rv = SECFailure; ++ unsigned int uOutLen; ++ CK_AEAD_PARAMS aeadParams; ++ static const int tagSize = 16; ++ ++ param.type = siBuffer; ++ param.len = sizeof(aeadParams); ++ param.data = (unsigned char *) &aeadParams; ++ memset(&aeadParams, 0, sizeof(CK_AEAD_PARAMS)); ++ aeadParams.pIv = (unsigned char *) additionalData; ++ aeadParams.ulIvLen = 8; ++ aeadParams.pAAD = (unsigned char *) additionalData; ++ aeadParams.ulAADLen = additionalDataLen; ++ aeadParams.ulTagBits = tagSize * 8; ++ ++ if (doDecrypt) { ++ rv = pk11_decrypt(keys->write_key, CKM_NSS_CHACHA20_POLY1305, ¶m, ++ out, &uOutLen, maxout, in, inlen); ++ } else { ++ rv = pk11_encrypt(keys->write_key, CKM_NSS_CHACHA20_POLY1305, ¶m, ++ out, &uOutLen, maxout, in, inlen); ++ } ++ *outlen = (int) uOutLen; ++ ++ return rv; ++} ++ + /* Initialize encryption and MAC contexts for pending spec. + * Master Secret already is derived. + * Caller holds Spec write lock. +@@ -2053,13 +2114,17 @@ ssl3_InitPendingContextsPKCS11(sslSocket *ss) + pwSpec->client.write_mac_context = NULL; + pwSpec->server.write_mac_context = NULL; + +- if (calg == calg_aes_gcm) { ++ if (calg == calg_aes_gcm || calg == calg_chacha20) { + pwSpec->encode = NULL; + pwSpec->decode = NULL; + pwSpec->destroy = NULL; + pwSpec->encodeContext = NULL; + pwSpec->decodeContext = NULL; +- pwSpec->aead = ssl3_AESGCM; ++ if (calg == calg_aes_gcm) { ++ pwSpec->aead = ssl3_AESGCM; ++ } else { ++ pwSpec->aead = ssl3_ChaCha20Poly1305; ++ } + return SECSuccess; + } + +diff --git a/nss/lib/ssl/ssl3ecc.c b/nss/lib/ssl/ssl3ecc.c +index a3638e7..21a5e05 100644 +--- a/nss/lib/ssl/ssl3ecc.c ++++ b/nss/lib/ssl/ssl3ecc.c +@@ -913,6 +913,7 @@ static const ssl3CipherSuite ecdhe_ecdsa_suites[] = { + TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, + TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, ++ TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, + TLS_ECDHE_ECDSA_WITH_NULL_SHA, + TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, + 0 /* end of list marker */ +@@ -924,6 +925,7 @@ static const ssl3CipherSuite ecdhe_rsa_suites[] = { + TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, + TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, ++ TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + TLS_ECDHE_RSA_WITH_NULL_SHA, + TLS_ECDHE_RSA_WITH_RC4_128_SHA, + 0 /* end of list marker */ +@@ -936,6 +938,7 @@ static const ssl3CipherSuite ecSuites[] = { + TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, + TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, ++ TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, + TLS_ECDHE_ECDSA_WITH_NULL_SHA, + TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, + TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, +@@ -943,6 +946,7 @@ static const ssl3CipherSuite ecSuites[] = { + TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, + TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, ++ TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + TLS_ECDHE_RSA_WITH_NULL_SHA, + TLS_ECDHE_RSA_WITH_RC4_128_SHA, + TLS_ECDH_ECDSA_WITH_3DES_EDE_CBC_SHA, +diff --git a/nss/lib/ssl/sslenum.c b/nss/lib/ssl/sslenum.c +index 597ec07..fc6b854 100644 +--- a/nss/lib/ssl/sslenum.c ++++ b/nss/lib/ssl/sslenum.c +@@ -31,6 +31,8 @@ + const PRUint16 SSL_ImplementedCiphers[] = { + /* AES-GCM */ + #ifdef NSS_ENABLE_ECC ++ TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, ++ TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + #endif /* NSS_ENABLE_ECC */ +diff --git a/nss/lib/ssl/sslimpl.h b/nss/lib/ssl/sslimpl.h +index 0fe12d0..e3ae9ce 100644 +--- a/nss/lib/ssl/sslimpl.h ++++ b/nss/lib/ssl/sslimpl.h +@@ -65,6 +65,7 @@ typedef SSLSignType SSL3SignType; + #define calg_camellia ssl_calg_camellia + #define calg_seed ssl_calg_seed + #define calg_aes_gcm ssl_calg_aes_gcm ++#define calg_chacha20 ssl_calg_chacha20 + + #define mac_null ssl_mac_null + #define mac_md5 ssl_mac_md5 +@@ -292,7 +293,7 @@ typedef struct { + } ssl3CipherSuiteCfg; + + #ifdef NSS_ENABLE_ECC +-#define ssl_V3_SUITES_IMPLEMENTED 61 ++#define ssl_V3_SUITES_IMPLEMENTED 63 + #else + #define ssl_V3_SUITES_IMPLEMENTED 37 + #endif /* NSS_ENABLE_ECC */ +@@ -474,6 +475,7 @@ typedef enum { + cipher_camellia_256, + cipher_seed, + cipher_aes_128_gcm, ++ cipher_chacha20, + cipher_missing /* reserved for no such supported cipher */ + /* This enum must match ssl3_cipherName[] in ssl3con.c. */ + } SSL3BulkCipher; +diff --git a/nss/lib/ssl/sslinfo.c b/nss/lib/ssl/sslinfo.c +index 9597209..bfc1676 100644 +--- a/nss/lib/ssl/sslinfo.c ++++ b/nss/lib/ssl/sslinfo.c +@@ -118,6 +118,7 @@ SSL_GetChannelInfo(PRFileDesc *fd, SSLChannelInfo *info, PRUintn len) + #define C_NULL "NULL", calg_null + #define C_SJ "SKIPJACK", calg_sj + #define C_AESGCM "AES-GCM", calg_aes_gcm ++#define C_CHACHA20 "CHACHA20POLY1305", calg_chacha20 + + #define B_256 256, 256, 256 + #define B_128 128, 128, 128 +@@ -196,12 +197,14 @@ static const SSLCipherSuiteInfo suiteInfo[] = { + {0,CS(TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA), S_ECDSA, K_ECDHE, C_AES, B_128, M_SHA, 1, 0, 0, }, + {0,CS(TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256), S_ECDSA, K_ECDHE, C_AES, B_128, M_SHA256, 1, 0, 0, }, + {0,CS(TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA), S_ECDSA, K_ECDHE, C_AES, B_256, M_SHA, 1, 0, 0, }, ++{0,CS(TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305),S_ECDSA,K_ECDHE,C_CHACHA20,B_256,M_AEAD_128,0, 0, 0, }, + + {0,CS(TLS_ECDH_RSA_WITH_NULL_SHA), S_RSA, K_ECDH, C_NULL, B_0, M_SHA, 0, 0, 0, }, + {0,CS(TLS_ECDH_RSA_WITH_RC4_128_SHA), S_RSA, K_ECDH, C_RC4, B_128, M_SHA, 0, 0, 0, }, + {0,CS(TLS_ECDH_RSA_WITH_3DES_EDE_CBC_SHA), S_RSA, K_ECDH, C_3DES, B_3DES, M_SHA, 1, 0, 0, }, + {0,CS(TLS_ECDH_RSA_WITH_AES_128_CBC_SHA), S_RSA, K_ECDH, C_AES, B_128, M_SHA, 1, 0, 0, }, + {0,CS(TLS_ECDH_RSA_WITH_AES_256_CBC_SHA), S_RSA, K_ECDH, C_AES, B_256, M_SHA, 1, 0, 0, }, ++{0,CS(TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305), S_RSA,K_ECDHE,C_CHACHA20,B_256,M_AEAD_128, 0, 0, 0, }, + + {0,CS(TLS_ECDHE_RSA_WITH_NULL_SHA), S_RSA, K_ECDHE, C_NULL, B_0, M_SHA, 0, 0, 0, }, + {0,CS(TLS_ECDHE_RSA_WITH_RC4_128_SHA), S_RSA, K_ECDHE, C_RC4, B_128, M_SHA, 0, 0, 0, }, +diff --git a/nss/lib/ssl/sslproto.h b/nss/lib/ssl/sslproto.h +index 53bba01..6b60a28 100644 +--- a/nss/lib/ssl/sslproto.h ++++ b/nss/lib/ssl/sslproto.h +@@ -213,6 +213,9 @@ + #define TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 0xC02F + #define TLS_ECDH_RSA_WITH_AES_128_GCM_SHA256 0xC031 + ++#define TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305 0xCC13 ++#define TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305 0xCC14 ++ + /* Netscape "experimental" cipher suites. */ + #define SSL_RSA_OLDFIPS_WITH_3DES_EDE_CBC_SHA 0xffe0 + #define SSL_RSA_OLDFIPS_WITH_DES_CBC_SHA 0xffe1 +diff --git a/nss/lib/ssl/sslsock.c b/nss/lib/ssl/sslsock.c +index c17c7a3..ffbccc6 100644 +--- a/nss/lib/ssl/sslsock.c ++++ b/nss/lib/ssl/sslsock.c +@@ -98,6 +98,7 @@ static cipherPolicy ssl_ciphers[] = { /* Export France */ + { TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256,SSL_NOT_ALLOWED, SSL_NOT_ALLOWED }, + { TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,SSL_NOT_ALLOWED, SSL_NOT_ALLOWED }, + { TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, SSL_NOT_ALLOWED, SSL_NOT_ALLOWED }, ++ { TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, SSL_NOT_ALLOWED, SSL_NOT_ALLOWED }, + { TLS_ECDH_RSA_WITH_NULL_SHA, SSL_ALLOWED, SSL_ALLOWED }, + { TLS_ECDH_RSA_WITH_RC4_128_SHA, SSL_NOT_ALLOWED, SSL_NOT_ALLOWED }, + { TLS_ECDH_RSA_WITH_3DES_EDE_CBC_SHA, SSL_NOT_ALLOWED, SSL_NOT_ALLOWED }, +@@ -110,6 +111,7 @@ static cipherPolicy ssl_ciphers[] = { /* Export France */ + { TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, SSL_NOT_ALLOWED, SSL_NOT_ALLOWED }, + { TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, SSL_NOT_ALLOWED, SSL_NOT_ALLOWED }, + { TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, SSL_NOT_ALLOWED, SSL_NOT_ALLOWED }, ++ { TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, SSL_NOT_ALLOWED, SSL_NOT_ALLOWED }, + #endif /* NSS_ENABLE_ECC */ + { 0, SSL_NOT_ALLOWED, SSL_NOT_ALLOWED } + }; +diff --git a/nss/lib/ssl/sslt.h b/nss/lib/ssl/sslt.h +index b03422e..a8007d8 100644 +--- a/nss/lib/ssl/sslt.h ++++ b/nss/lib/ssl/sslt.h +@@ -94,7 +94,8 @@ typedef enum { + ssl_calg_aes = 7, + ssl_calg_camellia = 8, + ssl_calg_seed = 9, +- ssl_calg_aes_gcm = 10 ++ ssl_calg_aes_gcm = 10, ++ ssl_calg_chacha20 = 11 + } SSLCipherAlgorithm; + + typedef enum { diff --git a/chromium/net/third_party/nss/patches/ciphersuiteversion.patch b/chromium/net/third_party/nss/patches/ciphersuiteversion.patch new file mode 100644 index 00000000000..3967f17dcd6 --- /dev/null +++ b/chromium/net/third_party/nss/patches/ciphersuiteversion.patch @@ -0,0 +1,169 @@ +diff --git a/nss/lib/ssl/ssl3con.c b/nss/lib/ssl/ssl3con.c +index bc54c99..1245393 100644 +--- a/nss/lib/ssl/ssl3con.c ++++ b/nss/lib/ssl/ssl3con.c +@@ -631,8 +631,9 @@ void SSL_AtomicIncrementLong(long * x) + } + + static PRBool +-ssl3_CipherSuiteAllowedForVersion(ssl3CipherSuite cipherSuite, +- SSL3ProtocolVersion version) ++ssl3_CipherSuiteAllowedForVersionRange( ++ ssl3CipherSuite cipherSuite, ++ const SSLVersionRange *vrange) + { + switch (cipherSuite) { + /* See RFC 4346 A.5. Export cipher suites must not be used in TLS 1.1 or +@@ -649,7 +650,9 @@ ssl3_CipherSuiteAllowedForVersion(ssl3CipherSuite cipherSuite, + * SSL_DH_ANON_EXPORT_WITH_RC4_40_MD5: never implemented + * SSL_DH_ANON_EXPORT_WITH_DES40_CBC_SHA: never implemented + */ +- return version <= SSL_LIBRARY_VERSION_TLS_1_0; ++ return vrange->min <= SSL_LIBRARY_VERSION_TLS_1_0; ++ case TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305: ++ case TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305: + case TLS_DHE_RSA_WITH_AES_256_CBC_SHA256: + case TLS_RSA_WITH_AES_256_CBC_SHA256: + case TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256: +@@ -661,7 +664,7 @@ ssl3_CipherSuiteAllowedForVersion(ssl3CipherSuite cipherSuite, + case TLS_RSA_WITH_AES_128_CBC_SHA256: + case TLS_RSA_WITH_AES_128_GCM_SHA256: + case TLS_RSA_WITH_NULL_SHA256: +- return version >= SSL_LIBRARY_VERSION_TLS_1_2; ++ return vrange->max >= SSL_LIBRARY_VERSION_TLS_1_2; + default: + return PR_TRUE; + } +@@ -804,7 +807,8 @@ ssl3_config_match_init(sslSocket *ss) + } + + +-/* return PR_TRUE if suite matches policy and enabled state */ ++/* return PR_TRUE if suite matches policy, enabled state and is applicable to ++ * the given version range. */ + /* It would be a REALLY BAD THING (tm) if we ever permitted the use + ** of a cipher that was NOT_ALLOWED. So, if this is ever called with + ** policy == SSL_NOT_ALLOWED, report no match. +@@ -812,7 +816,8 @@ ssl3_config_match_init(sslSocket *ss) + /* adjust suite enabled to the availability of a token that can do the + * cipher suite. */ + static PRBool +-config_match(ssl3CipherSuiteCfg *suite, int policy, PRBool enabled) ++config_match(ssl3CipherSuiteCfg *suite, int policy, PRBool enabled, ++ const SSLVersionRange *vrange) + { + PORT_Assert(policy != SSL_NOT_ALLOWED && enabled != PR_FALSE); + if (policy == SSL_NOT_ALLOWED || !enabled) +@@ -820,10 +825,13 @@ config_match(ssl3CipherSuiteCfg *suite, int policy, PRBool enabled) + return (PRBool)(suite->enabled && + suite->isPresent && + suite->policy != SSL_NOT_ALLOWED && +- suite->policy <= policy); ++ suite->policy <= policy && ++ ssl3_CipherSuiteAllowedForVersionRange( ++ suite->cipher_suite, vrange)); + } + +-/* return number of cipher suites that match policy and enabled state */ ++/* return number of cipher suites that match policy, enabled state and are ++ * applicable for the configured protocol version range. */ + /* called from ssl3_SendClientHello and ssl3_ConstructV2CipherSpecsHack */ + static int + count_cipher_suites(sslSocket *ss, int policy, PRBool enabled) +@@ -834,7 +842,7 @@ count_cipher_suites(sslSocket *ss, int policy, PRBool enabled) + return 0; + } + for (i = 0; i < ssl_V3_SUITES_IMPLEMENTED; i++) { +- if (config_match(&ss->cipherSuites[i], policy, enabled)) ++ if (config_match(&ss->cipherSuites[i], policy, enabled, &ss->vrange)) + count++; + } + if (count <= 0) { +@@ -5294,7 +5302,7 @@ ssl3_SendClientHello(sslSocket *ss, PRBool resending) + } + for (i = 0; i < ssl_V3_SUITES_IMPLEMENTED; i++) { + ssl3CipherSuiteCfg *suite = &ss->cipherSuites[i]; +- if (config_match(suite, ss->ssl3.policy, PR_TRUE)) { ++ if (config_match(suite, ss->ssl3.policy, PR_TRUE, &ss->vrange)) { + actual_count++; + if (actual_count > num_suites) { + /* set error card removal/insertion error */ +@@ -6359,15 +6367,19 @@ ssl3_HandleServerHello(sslSocket *ss, SSL3Opaque *b, PRUint32 length) + for (i = 0; i < ssl_V3_SUITES_IMPLEMENTED; i++) { + ssl3CipherSuiteCfg *suite = &ss->cipherSuites[i]; + if (temp == suite->cipher_suite) { +- if (!config_match(suite, ss->ssl3.policy, PR_TRUE)) { ++ SSLVersionRange vrange = {ss->version, ss->version}; ++ if (!config_match(suite, ss->ssl3.policy, PR_TRUE, &vrange)) { ++ /* config_match already checks whether the cipher suite is ++ * acceptable for the version, but the check is repeated here ++ * in order to give a more precise error code. */ ++ if (!ssl3_CipherSuiteAllowedForVersionRange(temp, &vrange)) { ++ desc = handshake_failure; ++ errCode = SSL_ERROR_CIPHER_DISALLOWED_FOR_VERSION; ++ goto alert_loser; ++ } ++ + break; /* failure */ + } +- if (!ssl3_CipherSuiteAllowedForVersion(suite->cipher_suite, +- ss->version)) { +- desc = handshake_failure; +- errCode = SSL_ERROR_CIPHER_DISALLOWED_FOR_VERSION; +- goto alert_loser; +- } + + suite_found = PR_TRUE; + break; /* success */ +@@ -8008,6 +8020,9 @@ ssl3_HandleClientHello(sslSocket *ss, SSL3Opaque *b, PRUint32 length) + */ + if (sid) do { + ssl3CipherSuiteCfg *suite; ++#ifdef PARANOID ++ SSLVersionRange vrange = {ss->version, ss->version}; ++#endif + + /* Check that the cached compression method is still enabled. */ + if (!compressionEnabled(ss, sid->u.ssl3.compression)) +@@ -8036,7 +8051,7 @@ ssl3_HandleClientHello(sslSocket *ss, SSL3Opaque *b, PRUint32 length) + * The product policy won't change during the process lifetime. + * Implemented ("isPresent") shouldn't change for servers. + */ +- if (!config_match(suite, ss->ssl3.policy, PR_TRUE)) ++ if (!config_match(suite, ss->ssl3.policy, PR_TRUE, &vrange)) + break; + #else + if (!suite->enabled) +@@ -8084,9 +8099,8 @@ ssl3_HandleClientHello(sslSocket *ss, SSL3Opaque *b, PRUint32 length) + */ + for (j = 0; j < ssl_V3_SUITES_IMPLEMENTED; j++) { + ssl3CipherSuiteCfg *suite = &ss->cipherSuites[j]; +- if (!config_match(suite, ss->ssl3.policy, PR_TRUE) || +- !ssl3_CipherSuiteAllowedForVersion(suite->cipher_suite, +- ss->version)) { ++ SSLVersionRange vrange = {ss->version, ss->version}; ++ if (!config_match(suite, ss->ssl3.policy, PR_TRUE, &vrange)) { + continue; + } + for (i = 0; i + 1 < suites.len; i += 2) { +@@ -8619,9 +8633,8 @@ ssl3_HandleV2ClientHello(sslSocket *ss, unsigned char *buffer, int length) + */ + for (j = 0; j < ssl_V3_SUITES_IMPLEMENTED; j++) { + ssl3CipherSuiteCfg *suite = &ss->cipherSuites[j]; +- if (!config_match(suite, ss->ssl3.policy, PR_TRUE) || +- !ssl3_CipherSuiteAllowedForVersion(suite->cipher_suite, +- ss->version)) { ++ SSLVersionRange vrange = {ss->version, ss->version}; ++ if (!config_match(suite, ss->ssl3.policy, PR_TRUE, &vrange)) { + continue; + } + for (i = 0; i+2 < suite_length; i += 3) { +@@ -12324,7 +12337,7 @@ ssl3_ConstructV2CipherSpecsHack(sslSocket *ss, unsigned char *cs, int *size) + /* ssl3_config_match_init was called by the caller of this function. */ + for (i = 0; i < ssl_V3_SUITES_IMPLEMENTED; i++) { + ssl3CipherSuiteCfg *suite = &ss->cipherSuites[i]; +- if (config_match(suite, SSL_ALLOWED, PR_TRUE)) { ++ if (config_match(suite, SSL_ALLOWED, PR_TRUE, &ss->vrange)) { + if (cs != NULL) { + *cs++ = 0x00; + *cs++ = (suite->cipher_suite >> 8) & 0xFF; diff --git a/chromium/net/third_party/nss/patches/nullcipher_934016.patch b/chromium/net/third_party/nss/patches/nullcipher_934016.patch new file mode 100644 index 00000000000..6a4b5c68d26 --- /dev/null +++ b/chromium/net/third_party/nss/patches/nullcipher_934016.patch @@ -0,0 +1,16 @@ +diff --git a/net/third_party/nss/ssl/ssl3con.c b/net/third_party/nss/ssl/ssl3con.c +index 8395f61..8b8b758 100644 +--- a/net/third_party/nss/ssl/ssl3con.c ++++ b/net/third_party/nss/ssl/ssl3con.c +@@ -859,6 +859,11 @@ static SECStatus + Null_Cipher(void *ctx, unsigned char *output, int *outputLen, int maxOutputLen, + const unsigned char *input, int inputLen) + { ++ if (inputLen > maxOutputLen) { ++ *outputLen = 0; /* Match PK11_CipherOp in setting outputLen */ ++ PORT_SetError(SEC_ERROR_OUTPUT_LEN); ++ return SECFailure; ++ } + *outputLen = inputLen; + if (input != output) + PORT_Memcpy(output, input, inputLen); diff --git a/chromium/net/third_party/nss/patches/peercertchain2.patch b/chromium/net/third_party/nss/patches/peercertchain2.patch new file mode 100644 index 00000000000..4b4a4fb5fa7 --- /dev/null +++ b/chromium/net/third_party/nss/patches/peercertchain2.patch @@ -0,0 +1,107 @@ +Index: net/third_party/nss/ssl/ssl.h +=================================================================== +--- net/third_party/nss/ssl/ssl.h (revision 225295) ++++ net/third_party/nss/ssl/ssl.h (working copy) +@@ -434,6 +434,15 @@ + */ + SSL_IMPORT CERTCertificate *SSL_PeerCertificate(PRFileDesc *fd); + ++/* ++** Return the certificates presented by the SSL peer. If the SSL peer ++** did not present certificates, return NULL with the ++** SSL_ERROR_NO_CERTIFICATE error. On failure, return NULL with an error ++** code other than SSL_ERROR_NO_CERTIFICATE. ++** "fd" the socket "file" descriptor ++*/ ++SSL_IMPORT CERTCertList *SSL_PeerCertificateChain(PRFileDesc *fd); ++ + /* SSL_PeerStapledOCSPResponses returns the OCSP responses that were provided + * by the TLS server. The return value is a pointer to an internal SECItemArray + * that contains the returned OCSP responses; it is only valid until the +@@ -463,18 +472,6 @@ + SSLKEAType kea); + + /* +-** Return references to the certificates presented by the SSL peer. +-** |maxNumCerts| must contain the size of the |certs| array. On successful +-** return, |*numCerts| contains the number of certificates available and +-** |certs| will contain references to as many certificates as would fit. +-** Therefore if |*numCerts| contains a value less than or equal to +-** |maxNumCerts|, then all certificates were returned. +-*/ +-SSL_IMPORT SECStatus SSL_PeerCertificateChain( +- PRFileDesc *fd, CERTCertificate **certs, +- unsigned int *numCerts, unsigned int maxNumCerts); +- +-/* + ** Authenticate certificate hook. Called when a certificate comes in + ** (because of SSL_REQUIRE_CERTIFICATE in SSL_Enable) to authenticate the + ** certificate. +Index: net/third_party/nss/ssl/sslauth.c +=================================================================== +--- net/third_party/nss/ssl/sslauth.c (revision 225295) ++++ net/third_party/nss/ssl/sslauth.c (working copy) +@@ -28,38 +28,43 @@ + } + + /* NEED LOCKS IN HERE. */ +-SECStatus +-SSL_PeerCertificateChain(PRFileDesc *fd, CERTCertificate **certs, +- unsigned int *numCerts, unsigned int maxNumCerts) ++CERTCertList * ++SSL_PeerCertificateChain(PRFileDesc *fd) + { + sslSocket *ss; +- ssl3CertNode* cur; ++ CERTCertList *chain = NULL; ++ CERTCertificate *cert; ++ ssl3CertNode *cur; + + ss = ssl_FindSocket(fd); + if (!ss) { + SSL_DBG(("%d: SSL[%d]: bad socket in PeerCertificateChain", + SSL_GETPID(), fd)); +- return SECFailure; ++ return NULL; + } +- if (!ss->opt.useSecurity) +- return SECFailure; +- +- if (ss->sec.peerCert == NULL) { +- *numCerts = 0; +- return SECSuccess; ++ if (!ss->opt.useSecurity || !ss->sec.peerCert) { ++ PORT_SetError(SSL_ERROR_NO_CERTIFICATE); ++ return NULL; + } +- +- *numCerts = 1; /* for the leaf certificate */ +- if (maxNumCerts > 0) +- certs[0] = CERT_DupCertificate(ss->sec.peerCert); +- ++ chain = CERT_NewCertList(); ++ if (!chain) { ++ return NULL; ++ } ++ cert = CERT_DupCertificate(ss->sec.peerCert); ++ if (CERT_AddCertToListTail(chain, cert) != SECSuccess) { ++ goto loser; ++ } + for (cur = ss->ssl3.peerCertChain; cur; cur = cur->next) { +- if (*numCerts < maxNumCerts) +- certs[*numCerts] = CERT_DupCertificate(cur->cert); +- (*numCerts)++; ++ cert = CERT_DupCertificate(cur->cert); ++ if (CERT_AddCertToListTail(chain, cert) != SECSuccess) { ++ goto loser; ++ } + } ++ return chain; + +- return SECSuccess; ++loser: ++ CERT_DestroyCertList(chain); ++ return NULL; + } + + /* NEED LOCKS IN HERE. */ diff --git a/chromium/net/third_party/nss/patches/resumeclienthelloversion.patch b/chromium/net/third_party/nss/patches/resumeclienthelloversion.patch new file mode 100644 index 00000000000..7c330c71fff --- /dev/null +++ b/chromium/net/third_party/nss/patches/resumeclienthelloversion.patch @@ -0,0 +1,31 @@ +diff --git a/nss/lib/ssl/ssl3con.c b/nss/lib/ssl/ssl3con.c +index d22a7d6..a7617fb 100644 +--- a/nss/lib/ssl/ssl3con.c ++++ b/nss/lib/ssl/ssl3con.c +@@ -2865,12 +2865,14 @@ ssl3_CompressMACEncryptRecord(ssl3CipherSpec * cwSpec, + * Forces the use of the provided epoch + * ssl_SEND_FLAG_CAP_RECORD_VERSION + * Caps the record layer version number of TLS ClientHello to { 3, 1 } +- * (TLS 1.0). Some TLS 1.0 servers (which seem to use F5 BIG-IP) ignore ++ * (TLS 1.0). Some TLS 1.0 servers (which seem to use F5 BIG-IP) ignore + * ClientHello.client_version and use the record layer version number + * (TLSPlaintext.version) instead when negotiating protocol versions. In + * addition, if the record layer version number of ClientHello is { 3, 2 } +- * (TLS 1.1) or higher, these servers reset the TCP connections. Set this +- * flag to work around such servers. ++ * (TLS 1.1) or higher, these servers reset the TCP connections. Lastly, ++ * some F5 BIG-IP servers hang if a record containing a ClientHello has a ++ * version greater than 0x0301 and a length greater than 255. Set this flag ++ * to work around such servers. + */ + PRInt32 + ssl3_SendRecord( sslSocket * ss, +@@ -5363,7 +5365,7 @@ ssl3_SendClientHello(sslSocket *ss, PRBool resending) + } + + flags = 0; +- if (!ss->firstHsDone && !requestingResume && !IS_DTLS(ss)) { ++ if (!ss->firstHsDone && !IS_DTLS(ss)) { + flags |= ssl_SEND_FLAG_CAP_RECORD_VERSION; + } + rv = ssl3_FlushHandshake(ss, flags); diff --git a/chromium/net/third_party/nss/ssl.gyp b/chromium/net/third_party/nss/ssl.gyp index fc526733c02..986b563ebf8 100644 --- a/chromium/net/third_party/nss/ssl.gyp +++ b/chromium/net/third_party/nss/ssl.gyp @@ -108,6 +108,13 @@ '-Wno-header-guard', ], }], + [ 'OS == "linux"', { + 'link_settings': { + 'libraries': [ + '-ldl', + ], + }, + }], [ 'OS == "mac" or OS == "ios"', { 'defines': [ 'XP_UNIX', diff --git a/chromium/net/third_party/nss/ssl/ssl.h b/chromium/net/third_party/nss/ssl/ssl.h index c083a6b2a60..47468a0a289 100644 --- a/chromium/net/third_party/nss/ssl/ssl.h +++ b/chromium/net/third_party/nss/ssl/ssl.h @@ -121,14 +121,22 @@ SSL_IMPORT PRFileDesc *DTLS_ImportFD(PRFileDesc *model, PRFileDesc *fd); #define SSL_ENABLE_FALSE_START 22 /* Enable SSL false start (off by */ /* default, applies only to */ /* clients). False start is a */ -/* mode where an SSL client will start sending application data before */ -/* verifying the server's Finished message. This means that we could end up */ -/* sending data to an imposter. However, the data will be encrypted and */ -/* only the true server can derive the session key. Thus, so long as the */ -/* cipher isn't broken this is safe. Because of this, False Start will only */ -/* occur on RSA or DH ciphersuites where the cipher's key length is >= 80 */ -/* bits. The advantage of False Start is that it saves a round trip for */ -/* client-speaks-first protocols when performing a full handshake. */ +/* mode where an SSL client will start sending application data before + * verifying the server's Finished message. This means that we could end up + * sending data to an imposter. However, the data will be encrypted and + * only the true server can derive the session key. Thus, so long as the + * cipher isn't broken this is safe. The advantage of false start is that + * it saves a round trip for client-speaks-first protocols when performing a + * full handshake. + * + * See SSL_DefaultCanFalseStart for the default criteria that NSS uses to + * determine whether to false start or not. See SSL_SetCanFalseStartCallback + * for how to change that criteria. In addition to those criteria, false start + * will only be done when the server selects a cipher suite with an effective + * key length of 80 bits or more (including RC4-128). Also, see + * SSL_HandshakeCallback for a description on how false start affects when the + * handshake callback gets called. + */ /* For SSL 3.0 and TLS 1.0, by default we prevent chosen plaintext attacks * on SSL CBC mode cipher suites (see RFC 4346 Section F.3) by splitting @@ -434,6 +442,15 @@ SSL_IMPORT SECStatus SSL_SecurityStatus(PRFileDesc *fd, int *on, char **cipher, */ SSL_IMPORT CERTCertificate *SSL_PeerCertificate(PRFileDesc *fd); +/* +** Return the certificates presented by the SSL peer. If the SSL peer +** did not present certificates, return NULL with the +** SSL_ERROR_NO_CERTIFICATE error. On failure, return NULL with an error +** code other than SSL_ERROR_NO_CERTIFICATE. +** "fd" the socket "file" descriptor +*/ +SSL_IMPORT CERTCertList *SSL_PeerCertificateChain(PRFileDesc *fd); + /* SSL_PeerStapledOCSPResponses returns the OCSP responses that were provided * by the TLS server. The return value is a pointer to an internal SECItemArray * that contains the returned OCSP responses; it is only valid until the @@ -463,18 +480,6 @@ SSL_SetStapledOCSPResponses(PRFileDesc *fd, const SECItemArray *responses, SSLKEAType kea); /* -** Return references to the certificates presented by the SSL peer. -** |maxNumCerts| must contain the size of the |certs| array. On successful -** return, |*numCerts| contains the number of certificates available and -** |certs| will contain references to as many certificates as would fit. -** Therefore if |*numCerts| contains a value less than or equal to -** |maxNumCerts|, then all certificates were returned. -*/ -SSL_IMPORT SECStatus SSL_PeerCertificateChain( - PRFileDesc *fd, CERTCertificate **certs, - unsigned int *numCerts, unsigned int maxNumCerts); - -/* ** Authenticate certificate hook. Called when a certificate comes in ** (because of SSL_REQUIRE_CERTIFICATE in SSL_Enable) to authenticate the ** certificate. @@ -744,14 +749,59 @@ SSL_IMPORT SECStatus SSL_SetMaxServerCacheLocks(PRUint32 maxLocks); SSL_IMPORT SECStatus SSL_InheritMPServerSIDCache(const char * envString); /* -** Set the callback on a particular socket that gets called when we finish -** performing a handshake. +** Set the callback that normally gets called when the TLS handshake +** is complete. If false start is not enabled, then the handshake callback is +** called after verifying the peer's Finished message and before sending +** outgoing application data and before processing incoming application data. +** +** If false start is enabled and there is a custom CanFalseStartCallback +** callback set, then the handshake callback gets called after the peer's +** Finished message has been verified, which may be after application data is +** sent. +** +** If false start is enabled and there is not a custom CanFalseStartCallback +** callback established with SSL_SetCanFalseStartCallback then the handshake +** callback gets called before any application data is sent, which may be +** before the peer's Finished message has been verified. */ typedef void (PR_CALLBACK *SSLHandshakeCallback)(PRFileDesc *fd, void *client_data); SSL_IMPORT SECStatus SSL_HandshakeCallback(PRFileDesc *fd, SSLHandshakeCallback cb, void *client_data); +/* Applications that wish to customize TLS false start should set this callback +** function. NSS will invoke the functon to determine if a particular +** connection should use false start or not. SECSuccess indicates that the +** callback completed successfully, and if so *canFalseStart indicates if false +** start can be used. If the callback does not return SECSuccess then the +** handshake will be canceled. +** +** Applications that do not set the callback will use an internal set of +** criteria to determine if the connection should false start. If +** the callback is set false start will never be used without invoking the +** callback function, but some connections (e.g. resumed connections) will +** never use false start and therefore will not invoke the callback. +** +** NSS's internal criteria for this connection can be evaluated by calling +** SSL_DefaultCanFalseStart() from the custom callback. +** +** See the description of SSL_HandshakeCallback for important information on +** how registering a custom false start callback affects when the handshake +** callback gets called. +**/ +typedef SECStatus (PR_CALLBACK *SSLCanFalseStartCallback)( + PRFileDesc *fd, void *arg, PRBool *canFalseStart); + +SSL_IMPORT SECStatus SSL_SetCanFalseStartCallback( + PRFileDesc *fd, SSLCanFalseStartCallback callback, void *arg); + +/* A utility function that can be called from a custom CanFalseStartCallback +** function to determine what NSS would have done for this connection if the +** custom callback was not implemented. +**/ +SSL_IMPORT SECStatus SSL_DefaultCanFalseStart(PRFileDesc *fd, + PRBool *canFalseStart); + /* ** For the server, request a new handshake. For the client, begin a new ** handshake. If flushCache is non-zero, the SSL3 cache entry will be diff --git a/chromium/net/third_party/nss/ssl/ssl3con.c b/chromium/net/third_party/nss/ssl/ssl3con.c index bf4ff6b716c..8b8b758c0b4 100644 --- a/chromium/net/third_party/nss/ssl/ssl3con.c +++ b/chromium/net/third_party/nss/ssl/ssl3con.c @@ -40,10 +40,28 @@ #define CKM_NSS_TLS_MASTER_KEY_DERIVE_DH_SHA256 (CKM_NSS + 24) #endif +/* This is a bodge to allow this code to be compiled against older NSS + * headers. */ +#ifndef CKM_NSS_CHACHA20_POLY1305 +#define CKM_NSS_CHACHA20_POLY1305 (CKM_NSS + 26) + +typedef struct CK_NSS_AEAD_PARAMS { + CK_BYTE_PTR pIv; /* This is the nonce. */ + CK_ULONG ulIvLen; + CK_BYTE_PTR pAAD; + CK_ULONG ulAADLen; + CK_ULONG ulTagLen; +} CK_NSS_AEAD_PARAMS; + +#endif + #include <stdio.h> #ifdef NSS_ENABLE_ZLIB #include "zlib.h" #endif +#ifdef LINUX +#include <dlfcn.h> +#endif #ifndef PK11_SETATTRS #define PK11_SETATTRS(x,id,v,l) (x)->type = (id); \ @@ -78,6 +96,13 @@ static int ssl3_OIDToTLSHashAlgorithm(SECOidTag oid); static SECStatus Null_Cipher(void *ctx, unsigned char *output, int *outputLen, int maxOutputLen, const unsigned char *input, int inputLen); +#ifndef NO_PKCS11_BYPASS +static SECStatus ssl3_AESGCMBypass(ssl3KeyMaterial *keys, PRBool doDecrypt, + unsigned char *out, int *outlen, int maxout, + const unsigned char *in, int inlen, + const unsigned char *additionalData, + int additionalDataLen); +#endif #define MAX_SEND_BUF_LENGTH 32000 /* watch for 16-bit integer overflow */ #define MIN_SEND_BUF_LENGTH 4000 @@ -90,6 +115,15 @@ static SECStatus Null_Cipher(void *ctx, unsigned char *output, int *outputLen, static ssl3CipherSuiteCfg cipherSuites[ssl_V3_SUITES_IMPLEMENTED] = { /* cipher_suite policy enabled is_present*/ #ifdef NSS_ENABLE_ECC + { TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, SSL_NOT_ALLOWED, PR_FALSE,PR_FALSE}, + { TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, SSL_NOT_ALLOWED, PR_FALSE,PR_FALSE}, + { TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,SSL_NOT_ALLOWED, PR_FALSE,PR_FALSE}, + { TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, SSL_NOT_ALLOWED, PR_FALSE,PR_FALSE}, +#endif /* NSS_ENABLE_ECC */ + { TLS_DHE_RSA_WITH_AES_128_GCM_SHA256, SSL_NOT_ALLOWED, PR_TRUE,PR_FALSE}, + { TLS_RSA_WITH_AES_128_GCM_SHA256, SSL_NOT_ALLOWED, PR_TRUE,PR_FALSE}, + +#ifdef NSS_ENABLE_ECC { TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, SSL_NOT_ALLOWED, PR_FALSE,PR_FALSE}, { TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, SSL_NOT_ALLOWED, PR_FALSE,PR_FALSE}, #endif /* NSS_ENABLE_ECC */ @@ -233,23 +267,31 @@ static SSL3Statistics ssl3stats; /* indexed by SSL3BulkCipher */ static const ssl3BulkCipherDef bulk_cipher_defs[] = { - /* cipher calg keySz secretSz type ivSz BlkSz keygen */ - {cipher_null, calg_null, 0, 0, type_stream, 0, 0, kg_null}, - {cipher_rc4, calg_rc4, 16, 16, type_stream, 0, 0, kg_strong}, - {cipher_rc4_40, calg_rc4, 16, 5, type_stream, 0, 0, kg_export}, - {cipher_rc4_56, calg_rc4, 16, 7, type_stream, 0, 0, kg_export}, - {cipher_rc2, calg_rc2, 16, 16, type_block, 8, 8, kg_strong}, - {cipher_rc2_40, calg_rc2, 16, 5, type_block, 8, 8, kg_export}, - {cipher_des, calg_des, 8, 8, type_block, 8, 8, kg_strong}, - {cipher_3des, calg_3des, 24, 24, type_block, 8, 8, kg_strong}, - {cipher_des40, calg_des, 8, 5, type_block, 8, 8, kg_export}, - {cipher_idea, calg_idea, 16, 16, type_block, 8, 8, kg_strong}, - {cipher_aes_128, calg_aes, 16, 16, type_block, 16,16, kg_strong}, - {cipher_aes_256, calg_aes, 32, 32, type_block, 16,16, kg_strong}, - {cipher_camellia_128, calg_camellia,16, 16, type_block, 16,16, kg_strong}, - {cipher_camellia_256, calg_camellia,32, 32, type_block, 16,16, kg_strong}, - {cipher_seed, calg_seed, 16, 16, type_block, 16,16, kg_strong}, - {cipher_missing, calg_null, 0, 0, type_stream, 0, 0, kg_null}, + /* |--------- Lengths --------| */ + /* cipher calg k s type i b t n */ + /* e e v l a o */ + /* y c | o g n */ + /* | r | c | c */ + /* | e | k | e */ + /* | t | | | | */ + {cipher_null, calg_null, 0, 0, type_stream, 0, 0, 0, 0}, + {cipher_rc4, calg_rc4, 16,16, type_stream, 0, 0, 0, 0}, + {cipher_rc4_40, calg_rc4, 16, 5, type_stream, 0, 0, 0, 0}, + {cipher_rc4_56, calg_rc4, 16, 7, type_stream, 0, 0, 0, 0}, + {cipher_rc2, calg_rc2, 16,16, type_block, 8, 8, 0, 0}, + {cipher_rc2_40, calg_rc2, 16, 5, type_block, 8, 8, 0, 0}, + {cipher_des, calg_des, 8, 8, type_block, 8, 8, 0, 0}, + {cipher_3des, calg_3des, 24,24, type_block, 8, 8, 0, 0}, + {cipher_des40, calg_des, 8, 5, type_block, 8, 8, 0, 0}, + {cipher_idea, calg_idea, 16,16, type_block, 8, 8, 0, 0}, + {cipher_aes_128, calg_aes, 16,16, type_block, 16,16, 0, 0}, + {cipher_aes_256, calg_aes, 32,32, type_block, 16,16, 0, 0}, + {cipher_camellia_128, calg_camellia, 16,16, type_block, 16,16, 0, 0}, + {cipher_camellia_256, calg_camellia, 32,32, type_block, 16,16, 0, 0}, + {cipher_seed, calg_seed, 16,16, type_block, 16,16, 0, 0}, + {cipher_aes_128_gcm, calg_aes_gcm, 16,16, type_aead, 4, 0,16, 8}, + {cipher_chacha20, calg_chacha20, 32,32, type_aead, 0, 0,16, 0}, + {cipher_missing, calg_null, 0, 0, type_stream, 0, 0, 0, 0}, }; static const ssl3KEADef kea_defs[] = @@ -371,6 +413,13 @@ static const ssl3CipherSuiteDef cipher_suite_defs[] = {SSL_RSA_FIPS_WITH_3DES_EDE_CBC_SHA, cipher_3des, mac_sha, kea_rsa_fips}, {SSL_RSA_FIPS_WITH_DES_CBC_SHA, cipher_des, mac_sha, kea_rsa_fips}, + {TLS_DHE_RSA_WITH_AES_128_GCM_SHA256, cipher_aes_128_gcm, mac_aead, kea_dhe_rsa}, + {TLS_RSA_WITH_AES_128_GCM_SHA256, cipher_aes_128_gcm, mac_aead, kea_rsa}, + {TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, cipher_aes_128_gcm, mac_aead, kea_ecdhe_rsa}, + {TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, cipher_aes_128_gcm, mac_aead, kea_ecdhe_ecdsa}, + {TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, cipher_chacha20, mac_aead, kea_ecdhe_rsa}, + {TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, cipher_chacha20, mac_aead, kea_ecdhe_ecdsa}, + #ifdef NSS_ENABLE_ECC {TLS_ECDH_ECDSA_WITH_NULL_SHA, cipher_null, mac_sha, kea_ecdh_ecdsa}, {TLS_ECDH_ECDSA_WITH_RC4_128_SHA, cipher_rc4, mac_sha, kea_ecdh_ecdsa}, @@ -434,25 +483,30 @@ static const SSLCipher2Mech alg2Mech[] = { { calg_aes , CKM_AES_CBC }, { calg_camellia , CKM_CAMELLIA_CBC }, { calg_seed , CKM_SEED_CBC }, + { calg_aes_gcm , CKM_AES_GCM }, + { calg_chacha20 , CKM_NSS_CHACHA20_POLY1305 }, /* { calg_init , (CK_MECHANISM_TYPE)0x7fffffffL } */ }; -#define mmech_null (CK_MECHANISM_TYPE)0x80000000L +#define mmech_invalid (CK_MECHANISM_TYPE)0x80000000L #define mmech_md5 CKM_SSL3_MD5_MAC #define mmech_sha CKM_SSL3_SHA1_MAC #define mmech_md5_hmac CKM_MD5_HMAC #define mmech_sha_hmac CKM_SHA_1_HMAC #define mmech_sha256_hmac CKM_SHA256_HMAC +#define mmech_sha384_hmac CKM_SHA384_HMAC +#define mmech_sha512_hmac CKM_SHA512_HMAC static const ssl3MACDef mac_defs[] = { /* indexed by SSL3MACAlgorithm */ /* pad_size is only used for SSL 3.0 MAC. See RFC 6101 Sec. 5.2.3.1. */ /* mac mmech pad_size mac_size */ - { mac_null, mmech_null, 0, 0 }, + { mac_null, mmech_invalid, 0, 0 }, { mac_md5, mmech_md5, 48, MD5_LENGTH }, { mac_sha, mmech_sha, 40, SHA1_LENGTH}, {hmac_md5, mmech_md5_hmac, 0, MD5_LENGTH }, {hmac_sha, mmech_sha_hmac, 0, SHA1_LENGTH}, {hmac_sha256, mmech_sha256_hmac, 0, SHA256_LENGTH}, + { mac_aead, mmech_invalid, 0, 0 }, }; /* indexed by SSL3BulkCipher */ @@ -472,6 +526,7 @@ const char * const ssl3_cipherName[] = { "Camellia-128", "Camellia-256", "SEED-CBC", + "AES-128-GCM", "missing" }; @@ -576,8 +631,9 @@ void SSL_AtomicIncrementLong(long * x) } static PRBool -ssl3_CipherSuiteAllowedForVersion(ssl3CipherSuite cipherSuite, - SSL3ProtocolVersion version) +ssl3_CipherSuiteAllowedForVersionRange( + ssl3CipherSuite cipherSuite, + const SSLVersionRange *vrange) { switch (cipherSuite) { /* See RFC 4346 A.5. Export cipher suites must not be used in TLS 1.1 or @@ -594,15 +650,21 @@ ssl3_CipherSuiteAllowedForVersion(ssl3CipherSuite cipherSuite, * SSL_DH_ANON_EXPORT_WITH_RC4_40_MD5: never implemented * SSL_DH_ANON_EXPORT_WITH_DES40_CBC_SHA: never implemented */ - return version <= SSL_LIBRARY_VERSION_TLS_1_0; + return vrange->min <= SSL_LIBRARY_VERSION_TLS_1_0; + case TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305: + case TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305: case TLS_DHE_RSA_WITH_AES_256_CBC_SHA256: case TLS_RSA_WITH_AES_256_CBC_SHA256: case TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256: + case TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256: case TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256: + case TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256: case TLS_DHE_RSA_WITH_AES_128_CBC_SHA256: + case TLS_DHE_RSA_WITH_AES_128_GCM_SHA256: case TLS_RSA_WITH_AES_128_CBC_SHA256: + case TLS_RSA_WITH_AES_128_GCM_SHA256: case TLS_RSA_WITH_NULL_SHA256: - return version >= SSL_LIBRARY_VERSION_TLS_1_2; + return vrange->max >= SSL_LIBRARY_VERSION_TLS_1_2; default: return PR_TRUE; } @@ -745,7 +807,8 @@ ssl3_config_match_init(sslSocket *ss) } -/* return PR_TRUE if suite matches policy and enabled state */ +/* return PR_TRUE if suite matches policy, enabled state and is applicable to + * the given version range. */ /* It would be a REALLY BAD THING (tm) if we ever permitted the use ** of a cipher that was NOT_ALLOWED. So, if this is ever called with ** policy == SSL_NOT_ALLOWED, report no match. @@ -753,7 +816,8 @@ ssl3_config_match_init(sslSocket *ss) /* adjust suite enabled to the availability of a token that can do the * cipher suite. */ static PRBool -config_match(ssl3CipherSuiteCfg *suite, int policy, PRBool enabled) +config_match(ssl3CipherSuiteCfg *suite, int policy, PRBool enabled, + const SSLVersionRange *vrange) { PORT_Assert(policy != SSL_NOT_ALLOWED && enabled != PR_FALSE); if (policy == SSL_NOT_ALLOWED || !enabled) @@ -761,10 +825,13 @@ config_match(ssl3CipherSuiteCfg *suite, int policy, PRBool enabled) return (PRBool)(suite->enabled && suite->isPresent && suite->policy != SSL_NOT_ALLOWED && - suite->policy <= policy); + suite->policy <= policy && + ssl3_CipherSuiteAllowedForVersionRange( + suite->cipher_suite, vrange)); } -/* return number of cipher suites that match policy and enabled state */ +/* return number of cipher suites that match policy, enabled state and are + * applicable for the configured protocol version range. */ /* called from ssl3_SendClientHello and ssl3_ConstructV2CipherSpecsHack */ static int count_cipher_suites(sslSocket *ss, int policy, PRBool enabled) @@ -775,7 +842,7 @@ count_cipher_suites(sslSocket *ss, int policy, PRBool enabled) return 0; } for (i = 0; i < ssl_V3_SUITES_IMPLEMENTED; i++) { - if (config_match(&ss->cipherSuites[i], policy, enabled)) + if (config_match(&ss->cipherSuites[i], policy, enabled, &ss->vrange)) count++; } if (count <= 0) { @@ -792,6 +859,11 @@ static SECStatus Null_Cipher(void *ctx, unsigned char *output, int *outputLen, int maxOutputLen, const unsigned char *input, int inputLen) { + if (inputLen > maxOutputLen) { + *outputLen = 0; /* Match PK11_CipherOp in setting outputLen */ + PORT_SetError(SEC_ERROR_OUTPUT_LEN); + return SECFailure; + } *outputLen = inputLen; if (input != output) PORT_Memcpy(output, input, inputLen); @@ -1360,7 +1432,7 @@ ssl3_SetupPendingCipherSpec(sslSocket *ss) cipher = suite_def->bulk_cipher_alg; kea = suite_def->key_exchange_alg; mac = suite_def->mac_alg; - if (mac <= ssl_mac_sha && isTLS) + if (mac <= ssl_mac_sha && mac != ssl_mac_null && isTLS) mac += 2; ss->ssl3.hs.suite_def = suite_def; @@ -1554,7 +1626,6 @@ ssl3_InitPendingContextsBypass(sslSocket *ss) unsigned int optArg2 = 0; PRBool server_encrypts = ss->sec.isServer; SSLCipherAlgorithm calg; - SSLCompressionMethod compression_method; SECStatus rv; PORT_Assert(ss->opt.noLocks || ssl_HaveSSL3HandshakeLock(ss)); @@ -1565,7 +1636,17 @@ ssl3_InitPendingContextsBypass(sslSocket *ss) cipher_def = pwSpec->cipher_def; calg = cipher_def->calg; - compression_method = pwSpec->compression_method; + + if (calg == calg_aes_gcm) { + pwSpec->encode = NULL; + pwSpec->decode = NULL; + pwSpec->destroy = NULL; + pwSpec->encodeContext = NULL; + pwSpec->decodeContext = NULL; + pwSpec->aead = ssl3_AESGCMBypass; + ssl3_InitCompressionContext(pwSpec); + return SECSuccess; + } serverContext = pwSpec->server.cipher_context; clientContext = pwSpec->client.cipher_context; @@ -1721,6 +1802,298 @@ ssl3_ParamFromIV(CK_MECHANISM_TYPE mtype, SECItem *iv, CK_ULONG ulEffectiveBits) return param; } +/* ssl3_BuildRecordPseudoHeader writes the SSL/TLS pseudo-header (the data + * which is included in the MAC or AEAD additional data) to |out| and returns + * its length. See https://tools.ietf.org/html/rfc5246#section-6.2.3.3 for the + * definition of the AEAD additional data. + * + * TLS pseudo-header includes the record's version field, SSL's doesn't. Which + * pseudo-header defintiion to use should be decided based on the version of + * the protocol that was negotiated when the cipher spec became current, NOT + * based on the version value in the record itself, and the decision is passed + * to this function as the |includesVersion| argument. But, the |version| + * argument should be the record's version value. + */ +static unsigned int +ssl3_BuildRecordPseudoHeader(unsigned char *out, + SSL3SequenceNumber seq_num, + SSL3ContentType type, + PRBool includesVersion, + SSL3ProtocolVersion version, + PRBool isDTLS, + int length) +{ + out[0] = (unsigned char)(seq_num.high >> 24); + out[1] = (unsigned char)(seq_num.high >> 16); + out[2] = (unsigned char)(seq_num.high >> 8); + out[3] = (unsigned char)(seq_num.high >> 0); + out[4] = (unsigned char)(seq_num.low >> 24); + out[5] = (unsigned char)(seq_num.low >> 16); + out[6] = (unsigned char)(seq_num.low >> 8); + out[7] = (unsigned char)(seq_num.low >> 0); + out[8] = type; + + /* SSL3 MAC doesn't include the record's version field. */ + if (!includesVersion) { + out[9] = MSB(length); + out[10] = LSB(length); + return 11; + } + + /* TLS MAC and AEAD additional data include version. */ + if (isDTLS) { + SSL3ProtocolVersion dtls_version; + + dtls_version = dtls_TLSVersionToDTLSVersion(version); + out[9] = MSB(dtls_version); + out[10] = LSB(dtls_version); + } else { + out[9] = MSB(version); + out[10] = LSB(version); + } + out[11] = MSB(length); + out[12] = LSB(length); + return 13; +} + +typedef SECStatus (*PK11CryptFcn)( + PK11SymKey *symKey, CK_MECHANISM_TYPE mechanism, SECItem *param, + unsigned char *out, unsigned int *outLen, unsigned int maxLen, + const unsigned char *in, unsigned int inLen); + +static PK11CryptFcn pk11_encrypt = NULL; +static PK11CryptFcn pk11_decrypt = NULL; + +static PRCallOnceType resolvePK11CryptOnce; + +static PRStatus +ssl3_ResolvePK11CryptFunctions(void) +{ +#ifdef LINUX + /* On Linux we use the system NSS libraries. Look up the PK11_Encrypt and + * PK11_Decrypt functions at run time. */ + void *handle = dlopen(NULL, RTLD_LAZY); + if (!handle) { + PORT_SetError(SEC_ERROR_LIBRARY_FAILURE); + return PR_FAILURE; + } + pk11_encrypt = (PK11CryptFcn)dlsym(handle, "PK11_Encrypt"); + pk11_decrypt = (PK11CryptFcn)dlsym(handle, "PK11_Decrypt"); + dlclose(handle); + return PR_SUCCESS; +#else + /* On other platforms we use our own copy of NSS. PK11_Encrypt and + * PK11_Decrypt are known to be available. */ + pk11_encrypt = PK11_Encrypt; + pk11_decrypt = PK11_Decrypt; + return PR_SUCCESS; +#endif +} + +/* + * In NSS 3.15, PK11_Encrypt and PK11_Decrypt were added to provide access + * to the AES GCM implementation in the NSS softoken. So the presence of + * these two functions implies the NSS version supports AES GCM. + */ +static PRBool +ssl3_HasGCMSupport(void) +{ + (void)PR_CallOnce(&resolvePK11CryptOnce, ssl3_ResolvePK11CryptFunctions); + return pk11_encrypt != NULL; +} + +/* On this socket, disable the GCM cipher suites */ +SECStatus +ssl3_DisableGCMSuites(sslSocket * ss) +{ + unsigned int i; + + for (i = 0; i < PR_ARRAY_SIZE(cipher_suite_defs); i++) { + const ssl3CipherSuiteDef *cipher_def = &cipher_suite_defs[i]; + if (cipher_def->bulk_cipher_alg == cipher_aes_128_gcm) { + SECStatus rv = ssl3_CipherPrefSet(ss, cipher_def->cipher_suite, + PR_FALSE); + PORT_Assert(rv == SECSuccess); /* else is coding error */ + } + } + return SECSuccess; +} + +static SECStatus +ssl3_AESGCM(ssl3KeyMaterial *keys, + PRBool doDecrypt, + unsigned char *out, + int *outlen, + int maxout, + const unsigned char *in, + int inlen, + const unsigned char *additionalData, + int additionalDataLen) +{ + SECItem param; + SECStatus rv = SECFailure; + unsigned char nonce[12]; + unsigned int uOutLen; + CK_GCM_PARAMS gcmParams; + + static const int tagSize = 16; + static const int explicitNonceLen = 8; + + /* See https://tools.ietf.org/html/rfc5288#section-3 for details of how the + * nonce is formed. */ + memcpy(nonce, keys->write_iv, 4); + if (doDecrypt) { + memcpy(nonce + 4, in, explicitNonceLen); + in += explicitNonceLen; + inlen -= explicitNonceLen; + *outlen = 0; + } else { + if (maxout < explicitNonceLen) { + PORT_SetError(SEC_ERROR_INPUT_LEN); + return SECFailure; + } + /* Use the 64-bit sequence number as the explicit nonce. */ + memcpy(nonce + 4, additionalData, explicitNonceLen); + memcpy(out, additionalData, explicitNonceLen); + out += explicitNonceLen; + maxout -= explicitNonceLen; + *outlen = explicitNonceLen; + } + + param.type = siBuffer; + param.data = (unsigned char *) &gcmParams; + param.len = sizeof(gcmParams); + gcmParams.pIv = nonce; + gcmParams.ulIvLen = sizeof(nonce); + gcmParams.pAAD = (unsigned char *)additionalData; /* const cast */ + gcmParams.ulAADLen = additionalDataLen; + gcmParams.ulTagBits = tagSize * 8; + + if (doDecrypt) { + rv = pk11_decrypt(keys->write_key, CKM_AES_GCM, ¶m, out, &uOutLen, + maxout, in, inlen); + } else { + rv = pk11_encrypt(keys->write_key, CKM_AES_GCM, ¶m, out, &uOutLen, + maxout, in, inlen); + } + *outlen += (int) uOutLen; + + return rv; +} + +#ifndef NO_PKCS11_BYPASS +static SECStatus +ssl3_AESGCMBypass(ssl3KeyMaterial *keys, + PRBool doDecrypt, + unsigned char *out, + int *outlen, + int maxout, + const unsigned char *in, + int inlen, + const unsigned char *additionalData, + int additionalDataLen) +{ + SECStatus rv = SECFailure; + unsigned char nonce[12]; + unsigned int uOutLen; + AESContext *cx; + CK_GCM_PARAMS gcmParams; + + static const int tagSize = 16; + static const int explicitNonceLen = 8; + + /* See https://tools.ietf.org/html/rfc5288#section-3 for details of how the + * nonce is formed. */ + PORT_Assert(keys->write_iv_item.len == 4); + if (keys->write_iv_item.len != 4) { + PORT_SetError(SEC_ERROR_LIBRARY_FAILURE); + return SECFailure; + } + memcpy(nonce, keys->write_iv_item.data, 4); + if (doDecrypt) { + memcpy(nonce + 4, in, explicitNonceLen); + in += explicitNonceLen; + inlen -= explicitNonceLen; + *outlen = 0; + } else { + if (maxout < explicitNonceLen) { + PORT_SetError(SEC_ERROR_INPUT_LEN); + return SECFailure; + } + /* Use the 64-bit sequence number as the explicit nonce. */ + memcpy(nonce + 4, additionalData, explicitNonceLen); + memcpy(out, additionalData, explicitNonceLen); + out += explicitNonceLen; + maxout -= explicitNonceLen; + *outlen = explicitNonceLen; + } + + gcmParams.pIv = nonce; + gcmParams.ulIvLen = sizeof(nonce); + gcmParams.pAAD = (unsigned char *)additionalData; /* const cast */ + gcmParams.ulAADLen = additionalDataLen; + gcmParams.ulTagBits = tagSize * 8; + + cx = (AESContext *)keys->cipher_context; + rv = AES_InitContext(cx, keys->write_key_item.data, + keys->write_key_item.len, + (unsigned char *)&gcmParams, NSS_AES_GCM, !doDecrypt, + AES_BLOCK_SIZE); + if (rv != SECSuccess) { + return rv; + } + if (doDecrypt) { + rv = AES_Decrypt(cx, out, &uOutLen, maxout, in, inlen); + } else { + rv = AES_Encrypt(cx, out, &uOutLen, maxout, in, inlen); + } + AES_DestroyContext(cx, PR_FALSE); + *outlen += (int) uOutLen; + + return rv; +} +#endif + +static SECStatus +ssl3_ChaCha20Poly1305( + ssl3KeyMaterial *keys, + PRBool doDecrypt, + unsigned char *out, + int *outlen, + int maxout, + const unsigned char *in, + int inlen, + const unsigned char *additionalData, + int additionalDataLen) +{ + SECItem param; + SECStatus rv = SECFailure; + unsigned int uOutLen; + CK_NSS_AEAD_PARAMS aeadParams; + static const int tagSize = 16; + + param.type = siBuffer; + param.len = sizeof(aeadParams); + param.data = (unsigned char *) &aeadParams; + memset(&aeadParams, 0, sizeof(aeadParams)); + aeadParams.pIv = (unsigned char *) additionalData; + aeadParams.ulIvLen = 8; + aeadParams.pAAD = (unsigned char *) additionalData; + aeadParams.ulAADLen = additionalDataLen; + aeadParams.ulTagLen = tagSize; + + if (doDecrypt) { + rv = pk11_decrypt(keys->write_key, CKM_NSS_CHACHA20_POLY1305, ¶m, + out, &uOutLen, maxout, in, inlen); + } else { + rv = pk11_encrypt(keys->write_key, CKM_NSS_CHACHA20_POLY1305, ¶m, + out, &uOutLen, maxout, in, inlen); + } + *outlen = (int) uOutLen; + + return rv; +} + /* Initialize encryption and MAC contexts for pending spec. * Master Secret already is derived. * Caller holds Spec write lock. @@ -1748,14 +2121,31 @@ ssl3_InitPendingContextsPKCS11(sslSocket *ss) pwSpec = ss->ssl3.pwSpec; cipher_def = pwSpec->cipher_def; macLength = pwSpec->mac_size; + calg = cipher_def->calg; + PORT_Assert(alg2Mech[calg].calg == calg); + + pwSpec->client.write_mac_context = NULL; + pwSpec->server.write_mac_context = NULL; + + if (calg == calg_aes_gcm || calg == calg_chacha20) { + pwSpec->encode = NULL; + pwSpec->decode = NULL; + pwSpec->destroy = NULL; + pwSpec->encodeContext = NULL; + pwSpec->decodeContext = NULL; + if (calg == calg_aes_gcm) { + pwSpec->aead = ssl3_AESGCM; + } else { + pwSpec->aead = ssl3_ChaCha20Poly1305; + } + return SECSuccess; + } /* ** Now setup the MAC contexts, ** crypto contexts are setup below. */ - pwSpec->client.write_mac_context = NULL; - pwSpec->server.write_mac_context = NULL; mac_mech = pwSpec->mac_def->mmech; mac_param.data = (unsigned char *)&macLength; mac_param.len = sizeof(macLength); @@ -1778,9 +2168,6 @@ ssl3_InitPendingContextsPKCS11(sslSocket *ss) ** Now setup the crypto contexts. */ - calg = cipher_def->calg; - PORT_Assert(alg2Mech[calg].calg == calg); - if (calg == calg_null) { pwSpec->encode = Null_Cipher; pwSpec->decode = Null_Cipher; @@ -1988,10 +2375,8 @@ static SECStatus ssl3_ComputeRecordMAC( ssl3CipherSpec * spec, PRBool useServerMacKey, - PRBool isDTLS, - SSL3ContentType type, - SSL3ProtocolVersion version, - SSL3SequenceNumber seq_num, + const unsigned char *header, + unsigned int headerLen, const SSL3Opaque * input, int inputLength, unsigned char * outbuf, @@ -1999,56 +2384,8 @@ ssl3_ComputeRecordMAC( { const ssl3MACDef * mac_def; SECStatus rv; -#ifndef NO_PKCS11_BYPASS - PRBool isTLS; -#endif - unsigned int tempLen; - unsigned char temp[MAX_MAC_LENGTH]; - - temp[0] = (unsigned char)(seq_num.high >> 24); - temp[1] = (unsigned char)(seq_num.high >> 16); - temp[2] = (unsigned char)(seq_num.high >> 8); - temp[3] = (unsigned char)(seq_num.high >> 0); - temp[4] = (unsigned char)(seq_num.low >> 24); - temp[5] = (unsigned char)(seq_num.low >> 16); - temp[6] = (unsigned char)(seq_num.low >> 8); - temp[7] = (unsigned char)(seq_num.low >> 0); - temp[8] = type; - - /* TLS MAC includes the record's version field, SSL's doesn't. - ** We decide which MAC defintiion to use based on the version of - ** the protocol that was negotiated when the spec became current, - ** NOT based on the version value in the record itself. - ** But, we use the record'v version value in the computation. - */ - if (spec->version <= SSL_LIBRARY_VERSION_3_0) { - temp[9] = MSB(inputLength); - temp[10] = LSB(inputLength); - tempLen = 11; -#ifndef NO_PKCS11_BYPASS - isTLS = PR_FALSE; -#endif - } else { - /* New TLS hash includes version. */ - if (isDTLS) { - SSL3ProtocolVersion dtls_version; - dtls_version = dtls_TLSVersionToDTLSVersion(version); - temp[9] = MSB(dtls_version); - temp[10] = LSB(dtls_version); - } else { - temp[9] = MSB(version); - temp[10] = LSB(version); - } - temp[11] = MSB(inputLength); - temp[12] = LSB(inputLength); - tempLen = 13; -#ifndef NO_PKCS11_BYPASS - isTLS = PR_TRUE; -#endif - } - - PRINT_BUF(95, (NULL, "frag hash1: temp", temp, tempLen)); + PRINT_BUF(95, (NULL, "frag hash1: header", header, headerLen)); PRINT_BUF(95, (NULL, "frag hash1: input", input, inputLength)); mac_def = spec->mac_def; @@ -2093,7 +2430,10 @@ ssl3_ComputeRecordMAC( return SECFailure; } - if (!isTLS) { + if (spec->version <= SSL_LIBRARY_VERSION_3_0) { + unsigned int tempLen; + unsigned char temp[MAX_MAC_LENGTH]; + /* compute "inner" part of SSL3 MAC */ hashObj->begin(write_mac_context); if (useServerMacKey) @@ -2105,7 +2445,7 @@ ssl3_ComputeRecordMAC( spec->client.write_mac_key_item.data, spec->client.write_mac_key_item.len); hashObj->update(write_mac_context, mac_pad_1, pad_bytes); - hashObj->update(write_mac_context, temp, tempLen); + hashObj->update(write_mac_context, header, headerLen); hashObj->update(write_mac_context, input, inputLength); hashObj->end(write_mac_context, temp, &tempLen, sizeof temp); @@ -2136,7 +2476,7 @@ ssl3_ComputeRecordMAC( } if (rv == SECSuccess) { HMAC_Begin(cx); - HMAC_Update(cx, temp, tempLen); + HMAC_Update(cx, header, headerLen); HMAC_Update(cx, input, inputLength); rv = HMAC_Finish(cx, outbuf, outLength, spec->mac_size); HMAC_Destroy(cx, PR_FALSE); @@ -2150,7 +2490,7 @@ ssl3_ComputeRecordMAC( (useServerMacKey ? spec->server.write_mac_context : spec->client.write_mac_context); rv = PK11_DigestBegin(mac_context); - rv |= PK11_DigestOp(mac_context, temp, tempLen); + rv |= PK11_DigestOp(mac_context, header, headerLen); rv |= PK11_DigestOp(mac_context, input, inputLength); rv |= PK11_DigestFinal(mac_context, outbuf, outLength, spec->mac_size); } @@ -2190,10 +2530,8 @@ static SECStatus ssl3_ComputeRecordMACConstantTime( ssl3CipherSpec * spec, PRBool useServerMacKey, - PRBool isDTLS, - SSL3ContentType type, - SSL3ProtocolVersion version, - SSL3SequenceNumber seq_num, + const unsigned char *header, + unsigned int headerLen, const SSL3Opaque * input, int inputLen, int originalLen, @@ -2205,9 +2543,7 @@ ssl3_ComputeRecordMACConstantTime( PK11Context * mac_context; SECItem param; SECStatus rv; - unsigned char header[13]; PK11SymKey * key; - int recordLength; PORT_Assert(inputLen >= spec->mac_size); PORT_Assert(originalLen >= inputLen); @@ -2223,42 +2559,15 @@ ssl3_ComputeRecordMACConstantTime( return SECSuccess; } - header[0] = (unsigned char)(seq_num.high >> 24); - header[1] = (unsigned char)(seq_num.high >> 16); - header[2] = (unsigned char)(seq_num.high >> 8); - header[3] = (unsigned char)(seq_num.high >> 0); - header[4] = (unsigned char)(seq_num.low >> 24); - header[5] = (unsigned char)(seq_num.low >> 16); - header[6] = (unsigned char)(seq_num.low >> 8); - header[7] = (unsigned char)(seq_num.low >> 0); - header[8] = type; - macType = CKM_NSS_HMAC_CONSTANT_TIME; - recordLength = inputLen - spec->mac_size; if (spec->version <= SSL_LIBRARY_VERSION_3_0) { macType = CKM_NSS_SSL3_MAC_CONSTANT_TIME; - header[9] = recordLength >> 8; - header[10] = recordLength; - params.ulHeaderLen = 11; - } else { - if (isDTLS) { - SSL3ProtocolVersion dtls_version; - - dtls_version = dtls_TLSVersionToDTLSVersion(version); - header[9] = dtls_version >> 8; - header[10] = dtls_version; - } else { - header[9] = version >> 8; - header[10] = version; - } - header[11] = recordLength >> 8; - header[12] = recordLength; - params.ulHeaderLen = 13; } params.macAlg = spec->mac_def->mmech; params.ulBodyTotalLen = originalLen; - params.pHeader = header; + params.pHeader = (unsigned char *) header; /* const cast */ + params.ulHeaderLen = headerLen; param.data = (unsigned char*) ¶ms; param.len = sizeof(params); @@ -2291,9 +2600,8 @@ fallback: /* ssl3_ComputeRecordMAC expects the MAC to have been removed from the * length already. */ inputLen -= spec->mac_size; - return ssl3_ComputeRecordMAC(spec, useServerMacKey, isDTLS, type, - version, seq_num, input, inputLen, - outbuf, outLen); + return ssl3_ComputeRecordMAC(spec, useServerMacKey, header, headerLen, + input, inputLen, outbuf, outLen); } static PRBool @@ -2345,6 +2653,8 @@ ssl3_CompressMACEncryptRecord(ssl3CipherSpec * cwSpec, PRUint16 headerLen; int ivLen = 0; int cipherBytes = 0; + unsigned char pseudoHeader[13]; + unsigned int pseudoHeaderLen; cipher_def = cwSpec->cipher_def; headerLen = isDTLS ? DTLS_RECORD_HEADER_LENGTH : SSL3_RECORD_HEADER_LENGTH; @@ -2390,86 +2700,117 @@ ssl3_CompressMACEncryptRecord(ssl3CipherSpec * cwSpec, contentLen = outlen; } - /* - * Add the MAC - */ - rv = ssl3_ComputeRecordMAC( cwSpec, isServer, isDTLS, - type, cwSpec->version, cwSpec->write_seq_num, pIn, contentLen, - wrBuf->buf + headerLen + ivLen + contentLen, &macLen); - if (rv != SECSuccess) { - ssl_MapLowLevelError(SSL_ERROR_MAC_COMPUTATION_FAILURE); - return SECFailure; - } - p1Len = contentLen; - p2Len = macLen; - fragLen = contentLen + macLen; /* needs to be encrypted */ - PORT_Assert(fragLen <= MAX_FRAGMENT_LENGTH + 1024); + pseudoHeaderLen = ssl3_BuildRecordPseudoHeader( + pseudoHeader, cwSpec->write_seq_num, type, + cwSpec->version >= SSL_LIBRARY_VERSION_TLS_1_0, cwSpec->version, + isDTLS, contentLen); + PORT_Assert(pseudoHeaderLen <= sizeof(pseudoHeader)); + if (cipher_def->type == type_aead) { + const int nonceLen = cipher_def->explicit_nonce_size; + const int tagLen = cipher_def->tag_size; - /* - * Pad the text (if we're doing a block cipher) - * then Encrypt it - */ - if (cipher_def->type == type_block) { - unsigned char * pBuf; - int padding_length; - int i; - - oddLen = contentLen % cipher_def->block_size; - /* Assume blockSize is a power of two */ - padding_length = cipher_def->block_size - 1 - - ((fragLen) & (cipher_def->block_size - 1)); - fragLen += padding_length + 1; - PORT_Assert((fragLen % cipher_def->block_size) == 0); - - /* Pad according to TLS rules (also acceptable to SSL3). */ - pBuf = &wrBuf->buf[headerLen + ivLen + fragLen - 1]; - for (i = padding_length + 1; i > 0; --i) { - *pBuf-- = padding_length; - } - /* now, if contentLen is not a multiple of block size, fix it */ - p2Len = fragLen - p1Len; - } - if (p1Len < 256) { - oddLen = p1Len; - p1Len = 0; - } else { - p1Len -= oddLen; - } - if (oddLen) { - p2Len += oddLen; - PORT_Assert( (cipher_def->block_size < 2) || \ - (p2Len % cipher_def->block_size) == 0); - memmove(wrBuf->buf + headerLen + ivLen + p1Len, pIn + p1Len, oddLen); - } - if (p1Len > 0) { - int cipherBytesPart1 = -1; - rv = cwSpec->encode( cwSpec->encodeContext, - wrBuf->buf + headerLen + ivLen, /* output */ - &cipherBytesPart1, /* actual outlen */ - p1Len, /* max outlen */ - pIn, p1Len); /* input, and inputlen */ - PORT_Assert(rv == SECSuccess && cipherBytesPart1 == (int) p1Len); - if (rv != SECSuccess || cipherBytesPart1 != (int) p1Len) { - PORT_SetError(SSL_ERROR_ENCRYPTION_FAILURE); + if (headerLen + nonceLen + contentLen + tagLen > wrBuf->space) { + PORT_SetError(SEC_ERROR_LIBRARY_FAILURE); return SECFailure; } - cipherBytes += cipherBytesPart1; - } - if (p2Len > 0) { - int cipherBytesPart2 = -1; - rv = cwSpec->encode( cwSpec->encodeContext, - wrBuf->buf + headerLen + ivLen + p1Len, - &cipherBytesPart2, /* output and actual outLen */ - p2Len, /* max outlen */ - wrBuf->buf + headerLen + ivLen + p1Len, - p2Len); /* input and inputLen*/ - PORT_Assert(rv == SECSuccess && cipherBytesPart2 == (int) p2Len); - if (rv != SECSuccess || cipherBytesPart2 != (int) p2Len) { + + cipherBytes = contentLen; + rv = cwSpec->aead( + isServer ? &cwSpec->server : &cwSpec->client, + PR_FALSE, /* do encrypt */ + wrBuf->buf + headerLen, /* output */ + &cipherBytes, /* out len */ + wrBuf->space - headerLen, /* max out */ + pIn, contentLen, /* input */ + pseudoHeader, pseudoHeaderLen); + if (rv != SECSuccess) { PORT_SetError(SSL_ERROR_ENCRYPTION_FAILURE); return SECFailure; } - cipherBytes += cipherBytesPart2; - } + } else { + /* + * Add the MAC + */ + rv = ssl3_ComputeRecordMAC(cwSpec, isServer, + pseudoHeader, pseudoHeaderLen, pIn, contentLen, + wrBuf->buf + headerLen + ivLen + contentLen, &macLen); + if (rv != SECSuccess) { + ssl_MapLowLevelError(SSL_ERROR_MAC_COMPUTATION_FAILURE); + return SECFailure; + } + p1Len = contentLen; + p2Len = macLen; + fragLen = contentLen + macLen; /* needs to be encrypted */ + PORT_Assert(fragLen <= MAX_FRAGMENT_LENGTH + 1024); + + /* + * Pad the text (if we're doing a block cipher) + * then Encrypt it + */ + if (cipher_def->type == type_block) { + unsigned char * pBuf; + int padding_length; + int i; + + oddLen = contentLen % cipher_def->block_size; + /* Assume blockSize is a power of two */ + padding_length = cipher_def->block_size - 1 - + ((fragLen) & (cipher_def->block_size - 1)); + fragLen += padding_length + 1; + PORT_Assert((fragLen % cipher_def->block_size) == 0); + + /* Pad according to TLS rules (also acceptable to SSL3). */ + pBuf = &wrBuf->buf[headerLen + ivLen + fragLen - 1]; + for (i = padding_length + 1; i > 0; --i) { + *pBuf-- = padding_length; + } + /* now, if contentLen is not a multiple of block size, fix it */ + p2Len = fragLen - p1Len; + } + if (p1Len < 256) { + oddLen = p1Len; + p1Len = 0; + } else { + p1Len -= oddLen; + } + if (oddLen) { + p2Len += oddLen; + PORT_Assert( (cipher_def->block_size < 2) || \ + (p2Len % cipher_def->block_size) == 0); + memmove(wrBuf->buf + headerLen + ivLen + p1Len, pIn + p1Len, + oddLen); + } + if (p1Len > 0) { + int cipherBytesPart1 = -1; + rv = cwSpec->encode( cwSpec->encodeContext, + wrBuf->buf + headerLen + ivLen, /* output */ + &cipherBytesPart1, /* actual outlen */ + p1Len, /* max outlen */ + pIn, p1Len); /* input, and inputlen */ + PORT_Assert(rv == SECSuccess && cipherBytesPart1 == (int) p1Len); + if (rv != SECSuccess || cipherBytesPart1 != (int) p1Len) { + PORT_SetError(SSL_ERROR_ENCRYPTION_FAILURE); + return SECFailure; + } + cipherBytes += cipherBytesPart1; + } + if (p2Len > 0) { + int cipherBytesPart2 = -1; + rv = cwSpec->encode( cwSpec->encodeContext, + wrBuf->buf + headerLen + ivLen + p1Len, + &cipherBytesPart2, /* output and actual outLen */ + p2Len, /* max outlen */ + wrBuf->buf + headerLen + ivLen + p1Len, + p2Len); /* input and inputLen*/ + PORT_Assert(rv == SECSuccess && cipherBytesPart2 == (int) p2Len); + if (rv != SECSuccess || cipherBytesPart2 != (int) p2Len) { + PORT_SetError(SSL_ERROR_ENCRYPTION_FAILURE); + return SECFailure; + } + cipherBytes += cipherBytesPart2; + } + } + PORT_Assert(cipherBytes <= MAX_FRAGMENT_LENGTH + 1024); wrBuf->len = cipherBytes + headerLen; @@ -2529,12 +2870,14 @@ ssl3_CompressMACEncryptRecord(ssl3CipherSpec * cwSpec, * Forces the use of the provided epoch * ssl_SEND_FLAG_CAP_RECORD_VERSION * Caps the record layer version number of TLS ClientHello to { 3, 1 } - * (TLS 1.0). Some TLS 1.0 servers (which seem to use F5 BIG-IP) ignore + * (TLS 1.0). Some TLS 1.0 servers (which seem to use F5 BIG-IP) ignore * ClientHello.client_version and use the record layer version number * (TLSPlaintext.version) instead when negotiating protocol versions. In * addition, if the record layer version number of ClientHello is { 3, 2 } - * (TLS 1.1) or higher, these servers reset the TCP connections. Set this - * flag to work around such servers. + * (TLS 1.1) or higher, these servers reset the TCP connections. Lastly, + * some F5 BIG-IP servers hang if a record containing a ClientHello has a + * version greater than 0x0301 and a length greater than 255. Set this flag + * to work around such servers. */ PRInt32 ssl3_SendRecord( sslSocket * ss, @@ -2552,7 +2895,7 @@ ssl3_SendRecord( sslSocket * ss, SSL_TRC(3, ("%d: SSL3[%d] SendRecord type: %s nIn=%d", SSL_GETPID(), ss->fd, ssl3_DecodeContentType(type), nIn)); - PRINT_BUF(3, (ss, "Send record (plain text)", pIn, nIn)); + PRINT_BUF(50, (ss, "Send record (plain text)", pIn, nIn)); PORT_Assert( ss->opt.noLocks || ssl_HaveXmitBufLock(ss) ); @@ -3012,9 +3355,6 @@ SSL3_SendAlert(sslSocket *ss, SSL3AlertLevel level, SSL3AlertDescription desc) static SECStatus ssl3_IllegalParameter(sslSocket *ss) { - PRBool isTLS; - - isTLS = (PRBool)(ss->ssl3.pwSpec->version > SSL_LIBRARY_VERSION_3_0); (void)SSL3_SendAlert(ss, alert_fatal, illegal_parameter); PORT_SetError(ss->sec.isServer ? SSL_ERROR_BAD_CLIENT : SSL_ERROR_BAD_SERVER ); @@ -3538,7 +3878,6 @@ ssl3_DeriveConnectionKeysPKCS11(sslSocket *ss) } key_material_params.bIsExport = (CK_BBOOL)(kea_def->is_limited); - /* was: (CK_BBOOL)(cipher_def->keygen_mode != kg_strong); */ key_material_params.RandomInfo.pClientRandom = cr; key_material_params.RandomInfo.ulClientRandomLen = SSL3_RANDOM_LENGTH; @@ -4875,6 +5214,10 @@ ssl3_SendClientHello(sslSocket *ss, PRBool resending) ssl3_DisableNonDTLSSuites(ss); } + if (!ssl3_HasGCMSupport()) { + ssl3_DisableGCMSuites(ss); + } + /* how many suites are permitted by policy and user preference? */ num_suites = count_cipher_suites(ss, ss->ssl3.policy, PR_TRUE); if (!num_suites) @@ -4966,7 +5309,7 @@ ssl3_SendClientHello(sslSocket *ss, PRBool resending) } for (i = 0; i < ssl_V3_SUITES_IMPLEMENTED; i++) { ssl3CipherSuiteCfg *suite = &ss->cipherSuites[i]; - if (config_match(suite, ss->ssl3.policy, PR_TRUE)) { + if (config_match(suite, ss->ssl3.policy, PR_TRUE, &ss->vrange)) { actual_count++; if (actual_count > num_suites) { /* set error card removal/insertion error */ @@ -5027,7 +5370,7 @@ ssl3_SendClientHello(sslSocket *ss, PRBool resending) } flags = 0; - if (!ss->firstHsDone && !requestingResume && !IS_DTLS(ss)) { + if (!ss->firstHsDone && !IS_DTLS(ss)) { flags |= ssl_SEND_FLAG_CAP_RECORD_VERSION; } rv = ssl3_FlushHandshake(ss, flags); @@ -5265,7 +5608,6 @@ SSL3_ShutdownServerCache(void) } PZ_Unlock(symWrapKeysLock); - ssl_FreeSessionCacheLocks(); return SECSuccess; } @@ -5317,7 +5659,7 @@ getWrappingKey( sslSocket * ss, pSymWrapKey = &symWrapKeys[symWrapMechIndex].symWrapKey[exchKeyType]; - ssl_InitSessionCacheLocks(PR_TRUE); + ssl_InitSessionCacheLocks(); PZ_Lock(symWrapKeysLock); @@ -6032,15 +6374,19 @@ ssl3_HandleServerHello(sslSocket *ss, SSL3Opaque *b, PRUint32 length) for (i = 0; i < ssl_V3_SUITES_IMPLEMENTED; i++) { ssl3CipherSuiteCfg *suite = &ss->cipherSuites[i]; if (temp == suite->cipher_suite) { - if (!config_match(suite, ss->ssl3.policy, PR_TRUE)) { + SSLVersionRange vrange = {ss->version, ss->version}; + if (!config_match(suite, ss->ssl3.policy, PR_TRUE, &vrange)) { + /* config_match already checks whether the cipher suite is + * acceptable for the version, but the check is repeated here + * in order to give a more precise error code. */ + if (!ssl3_CipherSuiteAllowedForVersionRange(temp, &vrange)) { + desc = handshake_failure; + errCode = SSL_ERROR_CIPHER_DISALLOWED_FOR_VERSION; + goto alert_loser; + } + break; /* failure */ } - if (!ssl3_CipherSuiteAllowedForVersion(suite->cipher_suite, - ss->version)) { - desc = handshake_failure; - errCode = SSL_ERROR_CIPHER_DISALLOWED_FOR_VERSION; - goto alert_loser; - } suite_found = PR_TRUE; break; /* success */ @@ -7003,35 +7349,42 @@ ssl3_RestartHandshakeAfterCertReq(sslSocket * ss, return rv; } -PRBool -ssl3_CanFalseStart(sslSocket *ss) { - PRBool rv; +static SECStatus +ssl3_CheckFalseStart(sslSocket *ss) +{ + SECStatus rv; + PRBool maybeFalseStart = PR_TRUE; PORT_Assert( ss->opt.noLocks || ssl_HaveSSL3HandshakeLock(ss) ); + PORT_Assert( !ss->ssl3.hs.authCertificatePending ); - /* XXX: does not take into account whether we are waiting for - * SSL_AuthCertificateComplete or SSL_RestartHandshakeAfterCertReq. If/when - * that is done, this function could return different results each time it - * would be called. - */ + /* An attacker can control the selected ciphersuite so we only wish to + * do False Start in the case that the selected ciphersuite is + * sufficiently strong that the attack can gain no advantage. + * Therefore we always require an 80-bit cipher. */ ssl_GetSpecReadLock(ss); - rv = ss->opt.enableFalseStart && - !ss->sec.isServer && - !ss->ssl3.hs.isResuming && - ss->ssl3.cwSpec && - - /* An attacker can control the selected ciphersuite so we only wish to - * do False Start in the case that the selected ciphersuite is - * sufficiently strong that the attack can gain no advantage. - * Therefore we require an 80-bit cipher and a forward-secret key - * exchange. */ - ss->ssl3.cwSpec->cipher_def->secret_key_size >= 10 && - (ss->ssl3.hs.kea_def->kea == kea_dhe_dss || - ss->ssl3.hs.kea_def->kea == kea_dhe_rsa || - ss->ssl3.hs.kea_def->kea == kea_ecdhe_ecdsa || - ss->ssl3.hs.kea_def->kea == kea_ecdhe_rsa); + if (ss->ssl3.cwSpec->cipher_def->secret_key_size < 10) { + ss->ssl3.hs.canFalseStart = PR_FALSE; + maybeFalseStart = PR_FALSE; + } ssl_ReleaseSpecReadLock(ss); + if (!maybeFalseStart) { + return SECSuccess; + } + + if (!ss->canFalseStartCallback) { + rv = SSL_DefaultCanFalseStart(ss->fd, &ss->ssl3.hs.canFalseStart); + } else { + rv = (ss->canFalseStartCallback)(ss->fd, + ss->canFalseStartCallbackData, + &ss->ssl3.hs.canFalseStart); + } + + if (rv != SECSuccess) { + ss->ssl3.hs.canFalseStart = PR_FALSE; + } + return rv; } @@ -7159,20 +7512,59 @@ ssl3_SendClientSecondRound(sslSocket *ss) goto loser; /* err code was set. */ } - /* XXX: If the server's certificate hasn't been authenticated by this - * point, then we may be leaking this NPN message to an attacker. + /* This must be done after we've set ss->ssl3.cwSpec in + * ssl3_SendChangeCipherSpecs because SSL_GetChannelInfo uses information + * from cwSpec. This must be done before we call ssl3_CheckFalseStart + * because the false start callback (if any) may need the information from + * the functions that depend on this being set. */ + ss->enoughFirstHsDone = PR_TRUE; + if (!ss->firstHsDone) { + /* XXX: If the server's certificate hasn't been authenticated by this + * point, then we may be leaking this NPN message to an attacker. + */ rv = ssl3_SendNextProto(ss); if (rv != SECSuccess) { goto loser; /* err code was set. */ } } + rv = ssl3_SendEncryptedExtensions(ss); if (rv != SECSuccess) { goto loser; /* err code was set. */ } + if (!ss->firstHsDone) { + if (ss->opt.enableFalseStart) { + if (!ss->ssl3.hs.authCertificatePending) { + /* When we fix bug 589047, we will need to know whether we are + * false starting before we try to flush the client second + * round to the network. With that in mind, we purposefully + * call ssl3_CheckFalseStart before calling ssl3_SendFinished, + * which includes a call to ssl3_FlushHandshake, so that + * no application develops a reliance on such flushing being + * done before its false start callback is called. + */ + ssl_ReleaseXmitBufLock(ss); + rv = ssl3_CheckFalseStart(ss); + ssl_GetXmitBufLock(ss); + if (rv != SECSuccess) { + goto loser; + } + } else { + /* The certificate authentication and the server's Finished + * message are racing each other. If the certificate + * authentication wins, then we will try to false start in + * ssl3_AuthCertificateComplete. + */ + SSL_TRC(3, ("%d: SSL3[%p]: deferring false start check because" + " certificate authentication is still pending.", + SSL_GETPID(), ss->fd)); + } + } + } + rv = ssl3_SendFinished(ss, 0); if (rv != SECSuccess) { goto loser; /* err code was set. */ @@ -7185,8 +7577,16 @@ ssl3_SendClientSecondRound(sslSocket *ss) else ss->ssl3.hs.ws = wait_change_cipher; - /* Do the handshake callback for sslv3 here, if we can false start. */ - if (ss->handshakeCallback != NULL && ssl3_CanFalseStart(ss)) { + if (ss->handshakeCallback && + (ss->ssl3.hs.canFalseStart && !ss->canFalseStartCallback)) { + /* Call the handshake callback here for backwards compatibility with + * applications that were using false start before + * canFalseStartCallback was added. Note that we do this after calling + * ssl3_SendFinished, which includes a call to ssl3_FlushHandshake, + * just in case the application is relying on having the handshake + * messages flushed to the network before its handshake callback is + * called. + */ (ss->handshakeCallback)(ss->fd, ss->handshakeCallbackData); } @@ -7662,6 +8062,10 @@ ssl3_HandleClientHello(sslSocket *ss, SSL3Opaque *b, PRUint32 length) ssl3_DisableNonDTLSSuites(ss); } + if (!ssl3_HasGCMSupport()) { + ssl3_DisableGCMSuites(ss); + } + #ifdef PARANOID /* Look for a matching cipher suite. */ j = ssl3_config_match_init(ss); @@ -7677,6 +8081,9 @@ ssl3_HandleClientHello(sslSocket *ss, SSL3Opaque *b, PRUint32 length) */ if (sid) do { ssl3CipherSuiteCfg *suite; +#ifdef PARANOID + SSLVersionRange vrange = {ss->version, ss->version}; +#endif /* Check that the cached compression method is still enabled. */ if (!compressionEnabled(ss, sid->u.ssl3.compression)) @@ -7705,7 +8112,7 @@ ssl3_HandleClientHello(sslSocket *ss, SSL3Opaque *b, PRUint32 length) * The product policy won't change during the process lifetime. * Implemented ("isPresent") shouldn't change for servers. */ - if (!config_match(suite, ss->ssl3.policy, PR_TRUE)) + if (!config_match(suite, ss->ssl3.policy, PR_TRUE, &vrange)) break; #else if (!suite->enabled) @@ -7753,9 +8160,8 @@ ssl3_HandleClientHello(sslSocket *ss, SSL3Opaque *b, PRUint32 length) */ for (j = 0; j < ssl_V3_SUITES_IMPLEMENTED; j++) { ssl3CipherSuiteCfg *suite = &ss->cipherSuites[j]; - if (!config_match(suite, ss->ssl3.policy, PR_TRUE) || - !ssl3_CipherSuiteAllowedForVersion(suite->cipher_suite, - ss->version)) { + SSLVersionRange vrange = {ss->version, ss->version}; + if (!config_match(suite, ss->ssl3.policy, PR_TRUE, &vrange)) { continue; } for (i = 0; i + 1 < suites.len; i += 2) { @@ -8288,9 +8694,8 @@ ssl3_HandleV2ClientHello(sslSocket *ss, unsigned char *buffer, int length) */ for (j = 0; j < ssl_V3_SUITES_IMPLEMENTED; j++) { ssl3CipherSuiteCfg *suite = &ss->cipherSuites[j]; - if (!config_match(suite, ss->ssl3.policy, PR_TRUE) || - !ssl3_CipherSuiteAllowedForVersion(suite->cipher_suite, - ss->version)) { + SSLVersionRange vrange = {ss->version, ss->version}; + if (!config_match(suite, ss->ssl3.policy, PR_TRUE, &vrange)) { continue; } for (i = 0; i+2 < suite_length; i += 3) { @@ -9801,13 +10206,6 @@ ssl3_AuthCertificate(sslSocket *ss) ss->ssl3.hs.authCertificatePending = PR_TRUE; rv = SECSuccess; - - /* XXX: Async cert validation and False Start don't work together - * safely yet; if we leave False Start enabled, we may end up false - * starting (sending application data) before we - * SSL_AuthCertificateComplete has been called. - */ - ss->opt.enableFalseStart = PR_FALSE; } if (rv != SECSuccess) { @@ -9932,6 +10330,12 @@ ssl3_AuthCertificateComplete(sslSocket *ss, PRErrorCode error) } else if (ss->ssl3.hs.restartTarget != NULL) { sslRestartTarget target = ss->ssl3.hs.restartTarget; ss->ssl3.hs.restartTarget = NULL; + + if (target == ssl3_FinishHandshake) { + SSL_TRC(3,("%d: SSL3[%p]: certificate authentication lost the race" + " with peer's finished message", SSL_GETPID(), ss->fd)); + } + rv = target(ss); /* Even if we blocked here, we have accomplished enough to claim * success. Any remaining work will be taken care of by subsequent @@ -9941,7 +10345,39 @@ ssl3_AuthCertificateComplete(sslSocket *ss, PRErrorCode error) rv = SECSuccess; } } else { - rv = SECSuccess; + SSL_TRC(3, ("%d: SSL3[%p]: certificate authentication won the race" + " with peer's finished message", SSL_GETPID(), ss->fd)); + + PORT_Assert(!ss->firstHsDone); + PORT_Assert(!ss->sec.isServer); + PORT_Assert(!ss->ssl3.hs.isResuming); + PORT_Assert(ss->ssl3.hs.ws == wait_change_cipher || + ss->ssl3.hs.ws == wait_finished || + ss->ssl3.hs.ws == wait_new_session_ticket); + + /* ssl3_SendClientSecondRound deferred the false start check because + * certificate authentication was pending, so we have to do it now. + */ + if (ss->opt.enableFalseStart && + !ss->firstHsDone && + !ss->sec.isServer && + !ss->ssl3.hs.isResuming && + (ss->ssl3.hs.ws == wait_change_cipher || + ss->ssl3.hs.ws == wait_finished || + ss->ssl3.hs.ws == wait_new_session_ticket)) { + rv = ssl3_CheckFalseStart(ss); + if (rv == SECSuccess && + ss->handshakeCallback && + (ss->ssl3.hs.canFalseStart && !ss->canFalseStartCallback)) { + /* Call the handshake callback here for backwards compatibility + * with applications that were using false start before + * canFalseStartCallback was added. + */ + (ss->handshakeCallback)(ss->fd, ss->handshakeCallbackData); + } + } else { + rv = SECSuccess; + } } done: @@ -10073,7 +10509,6 @@ ssl3_SendNextProto(sslSocket *ss) static void ssl3_RecordKeyLog(sslSocket *ss) { - sslSessionID *sid; SECStatus rv; SECItem *keyData; char buf[14 /* "CLIENT_RANDOM " */ + @@ -10085,8 +10520,6 @@ ssl3_RecordKeyLog(sslSocket *ss) PORT_Assert( ss->opt.noLocks || ssl_HaveSSL3HandshakeLock(ss)); - sid = ss->sec.ci.sid; - if (!ssl_keylog_iob) return; @@ -10640,6 +11073,8 @@ xmit_loser: SECStatus ssl3_FinishHandshake(sslSocket * ss) { + PRBool falseStarted; + PORT_Assert( ss->opt.noLocks || ssl_HaveRecvBufLock(ss) ); PORT_Assert( ss->opt.noLocks || ssl_HaveSSL3HandshakeLock(ss) ); PORT_Assert( ss->ssl3.hs.restartTarget == NULL ); @@ -10647,6 +11082,7 @@ ssl3_FinishHandshake(sslSocket * ss) /* The first handshake is now completed. */ ss->handshake = NULL; ss->firstHsDone = PR_TRUE; + ss->enoughFirstHsDone = PR_TRUE; if (ss->ssl3.hs.cacheSID) { (*ss->sec.cache)(ss->sec.ci.sid); @@ -10654,9 +11090,14 @@ ssl3_FinishHandshake(sslSocket * ss) } ss->ssl3.hs.ws = idle_handshake; + falseStarted = ss->ssl3.hs.canFalseStart; + ss->ssl3.hs.canFalseStart = PR_FALSE; /* False Start phase is complete */ - /* Do the handshake callback for sslv3 here, if we cannot false start. */ - if (ss->handshakeCallback != NULL && !ssl3_CanFalseStart(ss)) { + /* Call the handshake callback for sslv3 here, unless we called it already + * for the case where false start was done without a canFalseStartCallback. + */ + if (ss->handshakeCallback && + !(falseStarted && !ss->canFalseStartCallback)) { (ss->handshakeCallback)(ss->fd, ss->handshakeCallbackData); } @@ -11222,6 +11663,8 @@ ssl3_HandleRecord(sslSocket *ss, SSL3Ciphertext *cText, sslBuffer *databuf) unsigned int originalLen = 0; unsigned int good; unsigned int minLength; + unsigned char header[13]; + unsigned int headerLen; PORT_Assert( ss->opt.noLocks || ssl_HaveRecvBufLock(ss) ); @@ -11298,12 +11741,14 @@ ssl3_HandleRecord(sslSocket *ss, SSL3Ciphertext *cText, sslBuffer *databuf) /* With >= TLS 1.1, CBC records have an explicit IV. */ minLength += cipher_def->iv_size; } + } else if (cipher_def->type == type_aead) { + minLength = cipher_def->explicit_nonce_size + cipher_def->tag_size; } /* We can perform this test in variable time because the record's total * length and the ciphersuite are both public knowledge. */ if (cText->buf->len < minLength) { - goto decrypt_loser; + goto decrypt_loser; } if (cipher_def->type == type_block && @@ -11371,78 +11816,104 @@ ssl3_HandleRecord(sslSocket *ss, SSL3Ciphertext *cText, sslBuffer *databuf) return SECFailure; } - if (cipher_def->type == type_block && - ((cText->buf->len - ivLen) % cipher_def->block_size) != 0) { - goto decrypt_loser; - } + rType = cText->type; + if (cipher_def->type == type_aead) { + /* XXX For many AEAD ciphers, the plaintext is shorter than the + * ciphertext by a fixed byte count, but it is not true in general. + * Each AEAD cipher should provide a function that returns the + * plaintext length for a given ciphertext. */ + unsigned int decryptedLen = + cText->buf->len - cipher_def->explicit_nonce_size - + cipher_def->tag_size; + headerLen = ssl3_BuildRecordPseudoHeader( + header, IS_DTLS(ss) ? cText->seq_num : crSpec->read_seq_num, + rType, isTLS, cText->version, IS_DTLS(ss), decryptedLen); + PORT_Assert(headerLen <= sizeof(header)); + rv = crSpec->aead( + ss->sec.isServer ? &crSpec->client : &crSpec->server, + PR_TRUE, /* do decrypt */ + plaintext->buf, /* out */ + (int*) &plaintext->len, /* outlen */ + plaintext->space, /* maxout */ + cText->buf->buf, /* in */ + cText->buf->len, /* inlen */ + header, headerLen); + if (rv != SECSuccess) { + good = 0; + } + } else { + if (cipher_def->type == type_block && + ((cText->buf->len - ivLen) % cipher_def->block_size) != 0) { + goto decrypt_loser; + } - /* decrypt from cText buf to plaintext. */ - rv = crSpec->decode( - crSpec->decodeContext, plaintext->buf, (int *)&plaintext->len, - plaintext->space, cText->buf->buf + ivLen, cText->buf->len - ivLen); - if (rv != SECSuccess) { - goto decrypt_loser; - } + /* decrypt from cText buf to plaintext. */ + rv = crSpec->decode( + crSpec->decodeContext, plaintext->buf, (int *)&plaintext->len, + plaintext->space, cText->buf->buf + ivLen, cText->buf->len - ivLen); + if (rv != SECSuccess) { + goto decrypt_loser; + } - PRINT_BUF(80, (ss, "cleartext:", plaintext->buf, plaintext->len)); + PRINT_BUF(80, (ss, "cleartext:", plaintext->buf, plaintext->len)); - originalLen = plaintext->len; + originalLen = plaintext->len; - /* If it's a block cipher, check and strip the padding. */ - if (cipher_def->type == type_block) { - const unsigned int blockSize = cipher_def->block_size; - const unsigned int macSize = crSpec->mac_size; + /* If it's a block cipher, check and strip the padding. */ + if (cipher_def->type == type_block) { + const unsigned int blockSize = cipher_def->block_size; + const unsigned int macSize = crSpec->mac_size; - if (crSpec->version <= SSL_LIBRARY_VERSION_3_0) { - good &= SECStatusToMask(ssl_RemoveSSLv3CBCPadding( - plaintext, blockSize, macSize)); - } else { - good &= SECStatusToMask(ssl_RemoveTLSCBCPadding( - plaintext, macSize)); + if (!isTLS) { + good &= SECStatusToMask(ssl_RemoveSSLv3CBCPadding( + plaintext, blockSize, macSize)); + } else { + good &= SECStatusToMask(ssl_RemoveTLSCBCPadding( + plaintext, macSize)); + } } - } - /* compute the MAC */ - rType = cText->type; - if (cipher_def->type == type_block) { - rv = ssl3_ComputeRecordMACConstantTime( - crSpec, (PRBool)(!ss->sec.isServer), - IS_DTLS(ss), rType, cText->version, - IS_DTLS(ss) ? cText->seq_num : crSpec->read_seq_num, - plaintext->buf, plaintext->len, originalLen, - hash, &hashBytes); - - ssl_CBCExtractMAC(plaintext, originalLen, givenHashBuf, - crSpec->mac_size); - givenHash = givenHashBuf; - - /* plaintext->len will always have enough space to remove the MAC - * because in ssl_Remove{SSLv3|TLS}CBCPadding we only adjust - * plaintext->len if the result has enough space for the MAC and we - * tested the unadjusted size against minLength, above. */ - plaintext->len -= crSpec->mac_size; - } else { - /* This is safe because we checked the minLength above. */ - plaintext->len -= crSpec->mac_size; + /* compute the MAC */ + headerLen = ssl3_BuildRecordPseudoHeader( + header, IS_DTLS(ss) ? cText->seq_num : crSpec->read_seq_num, + rType, isTLS, cText->version, IS_DTLS(ss), + plaintext->len - crSpec->mac_size); + PORT_Assert(headerLen <= sizeof(header)); + if (cipher_def->type == type_block) { + rv = ssl3_ComputeRecordMACConstantTime( + crSpec, (PRBool)(!ss->sec.isServer), header, headerLen, + plaintext->buf, plaintext->len, originalLen, + hash, &hashBytes); + + ssl_CBCExtractMAC(plaintext, originalLen, givenHashBuf, + crSpec->mac_size); + givenHash = givenHashBuf; + + /* plaintext->len will always have enough space to remove the MAC + * because in ssl_Remove{SSLv3|TLS}CBCPadding we only adjust + * plaintext->len if the result has enough space for the MAC and we + * tested the unadjusted size against minLength, above. */ + plaintext->len -= crSpec->mac_size; + } else { + /* This is safe because we checked the minLength above. */ + plaintext->len -= crSpec->mac_size; - rv = ssl3_ComputeRecordMAC( - crSpec, (PRBool)(!ss->sec.isServer), - IS_DTLS(ss), rType, cText->version, - IS_DTLS(ss) ? cText->seq_num : crSpec->read_seq_num, - plaintext->buf, plaintext->len, - hash, &hashBytes); + rv = ssl3_ComputeRecordMAC( + crSpec, (PRBool)(!ss->sec.isServer), header, headerLen, + plaintext->buf, plaintext->len, hash, &hashBytes); - /* We can read the MAC directly from the record because its location is - * public when a stream cipher is used. */ - givenHash = plaintext->buf + plaintext->len; - } + /* We can read the MAC directly from the record because its location + * is public when a stream cipher is used. */ + givenHash = plaintext->buf + plaintext->len; + } - good &= SECStatusToMask(rv); + good &= SECStatusToMask(rv); - if (hashBytes != (unsigned)crSpec->mac_size || - NSS_SecureMemcmp(givenHash, hash, crSpec->mac_size) != 0) { - /* We're allowed to leak whether or not the MAC check was correct */ - good = 0; + if (hashBytes != (unsigned)crSpec->mac_size || + NSS_SecureMemcmp(givenHash, hash, crSpec->mac_size) != 0) { + /* We're allowed to leak whether or not the MAC check was correct */ + good = 0; + } } if (good == 0) { @@ -11966,7 +12437,7 @@ ssl3_ConstructV2CipherSpecsHack(sslSocket *ss, unsigned char *cs, int *size) /* ssl3_config_match_init was called by the caller of this function. */ for (i = 0; i < ssl_V3_SUITES_IMPLEMENTED; i++) { ssl3CipherSuiteCfg *suite = &ss->cipherSuites[i]; - if (config_match(suite, SSL_ALLOWED, PR_TRUE)) { + if (config_match(suite, SSL_ALLOWED, PR_TRUE, &ss->vrange)) { if (cs != NULL) { *cs++ = 0x00; *cs++ = (suite->cipher_suite >> 8) & 0xFF; diff --git a/chromium/net/third_party/nss/ssl/ssl3ecc.c b/chromium/net/third_party/nss/ssl/ssl3ecc.c index 74995f18321..21a5e05cf2d 100644 --- a/chromium/net/third_party/nss/ssl/ssl3ecc.c +++ b/chromium/net/third_party/nss/ssl/ssl3ecc.c @@ -911,7 +911,9 @@ static const ssl3CipherSuite ecdhe_ecdsa_suites[] = { TLS_ECDHE_ECDSA_WITH_3DES_EDE_CBC_SHA, TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, + TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, + TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, TLS_ECDHE_ECDSA_WITH_NULL_SHA, TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, 0 /* end of list marker */ @@ -921,7 +923,9 @@ static const ssl3CipherSuite ecdhe_rsa_suites[] = { TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, + TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, + TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, TLS_ECDHE_RSA_WITH_NULL_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA, 0 /* end of list marker */ @@ -932,13 +936,17 @@ static const ssl3CipherSuite ecSuites[] = { TLS_ECDHE_ECDSA_WITH_3DES_EDE_CBC_SHA, TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, + TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, + TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, TLS_ECDHE_ECDSA_WITH_NULL_SHA, TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, + TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, + TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, TLS_ECDHE_RSA_WITH_NULL_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA, TLS_ECDH_ECDSA_WITH_3DES_EDE_CBC_SHA, diff --git a/chromium/net/third_party/nss/ssl/ssl3gthr.c b/chromium/net/third_party/nss/ssl/ssl3gthr.c index 6d625152662..7385d6504e8 100644 --- a/chromium/net/third_party/nss/ssl/ssl3gthr.c +++ b/chromium/net/third_party/nss/ssl/ssl3gthr.c @@ -374,9 +374,7 @@ ssl3_GatherCompleteHandshake(sslSocket *ss, int flags) */ if (ss->opt.enableFalseStart) { ssl_GetSSL3HandshakeLock(ss); - canFalseStart = (ss->ssl3.hs.ws == wait_change_cipher || - ss->ssl3.hs.ws == wait_new_session_ticket) && - ssl3_CanFalseStart(ss); + canFalseStart = ss->ssl3.hs.canFalseStart; ssl_ReleaseSSL3HandshakeLock(ss); } } while (ss->ssl3.hs.ws != idle_handshake && diff --git a/chromium/net/third_party/nss/ssl/sslauth.c b/chromium/net/third_party/nss/ssl/sslauth.c index 8e818decd85..695cab854e4 100644 --- a/chromium/net/third_party/nss/ssl/sslauth.c +++ b/chromium/net/third_party/nss/ssl/sslauth.c @@ -28,38 +28,43 @@ SSL_PeerCertificate(PRFileDesc *fd) } /* NEED LOCKS IN HERE. */ -SECStatus -SSL_PeerCertificateChain(PRFileDesc *fd, CERTCertificate **certs, - unsigned int *numCerts, unsigned int maxNumCerts) +CERTCertList * +SSL_PeerCertificateChain(PRFileDesc *fd) { sslSocket *ss; - ssl3CertNode* cur; + CERTCertList *chain = NULL; + CERTCertificate *cert; + ssl3CertNode *cur; ss = ssl_FindSocket(fd); if (!ss) { SSL_DBG(("%d: SSL[%d]: bad socket in PeerCertificateChain", SSL_GETPID(), fd)); - return SECFailure; + return NULL; } - if (!ss->opt.useSecurity) - return SECFailure; - - if (ss->sec.peerCert == NULL) { - *numCerts = 0; - return SECSuccess; + if (!ss->opt.useSecurity || !ss->sec.peerCert) { + PORT_SetError(SSL_ERROR_NO_CERTIFICATE); + return NULL; + } + chain = CERT_NewCertList(); + if (!chain) { + return NULL; + } + cert = CERT_DupCertificate(ss->sec.peerCert); + if (CERT_AddCertToListTail(chain, cert) != SECSuccess) { + goto loser; } - - *numCerts = 1; /* for the leaf certificate */ - if (maxNumCerts > 0) - certs[0] = CERT_DupCertificate(ss->sec.peerCert); - for (cur = ss->ssl3.peerCertChain; cur; cur = cur->next) { - if (*numCerts < maxNumCerts) - certs[*numCerts] = CERT_DupCertificate(cur->cert); - (*numCerts)++; + cert = CERT_DupCertificate(cur->cert); + if (CERT_AddCertToListTail(chain, cert) != SECSuccess) { + goto loser; + } } + return chain; - return SECSuccess; +loser: + CERT_DestroyCertList(chain); + return NULL; } /* NEED LOCKS IN HERE. */ @@ -95,7 +100,6 @@ SSL_SecurityStatus(PRFileDesc *fd, int *op, char **cp, int *kp0, int *kp1, sslSocket *ss; const char *cipherName; PRBool isDes = PR_FALSE; - PRBool enoughFirstHsDone = PR_FALSE; ss = ssl_FindSocket(fd); if (!ss) { @@ -113,14 +117,7 @@ SSL_SecurityStatus(PRFileDesc *fd, int *op, char **cp, int *kp0, int *kp1, *op = SSL_SECURITY_STATUS_OFF; } - if (ss->firstHsDone) { - enoughFirstHsDone = PR_TRUE; - } else if (ss->version >= SSL_LIBRARY_VERSION_3_0 && - ssl3_CanFalseStart(ss)) { - enoughFirstHsDone = PR_TRUE; - } - - if (ss->opt.useSecurity && enoughFirstHsDone) { + if (ss->opt.useSecurity && ss->enoughFirstHsDone) { if (ss->version < SSL_LIBRARY_VERSION_3_0) { cipherName = ssl_cipherName[ss->sec.cipherType]; } else { diff --git a/chromium/net/third_party/nss/ssl/sslenum.c b/chromium/net/third_party/nss/ssl/sslenum.c index b460f2631dc..fc6b85423b1 100644 --- a/chromium/net/third_party/nss/ssl/sslenum.c +++ b/chromium/net/third_party/nss/ssl/sslenum.c @@ -29,6 +29,16 @@ * Finally, update the ssl_V3_SUITES_IMPLEMENTED macro in sslimpl.h. */ const PRUint16 SSL_ImplementedCiphers[] = { + /* AES-GCM */ +#ifdef NSS_ENABLE_ECC + TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, + TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, +#endif /* NSS_ENABLE_ECC */ + TLS_DHE_RSA_WITH_AES_128_GCM_SHA256, + TLS_RSA_WITH_AES_128_GCM_SHA256, + /* 256-bit */ #ifdef NSS_ENABLE_ECC TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, diff --git a/chromium/net/third_party/nss/ssl/sslimpl.h b/chromium/net/third_party/nss/ssl/sslimpl.h index 40d915616b3..614eed145ec 100644 --- a/chromium/net/third_party/nss/ssl/sslimpl.h +++ b/chromium/net/third_party/nss/ssl/sslimpl.h @@ -64,6 +64,8 @@ typedef SSLSignType SSL3SignType; #define calg_aes ssl_calg_aes #define calg_camellia ssl_calg_camellia #define calg_seed ssl_calg_seed +#define calg_aes_gcm ssl_calg_aes_gcm +#define calg_chacha20 ssl_calg_chacha20 #define mac_null ssl_mac_null #define mac_md5 ssl_mac_md5 @@ -71,6 +73,7 @@ typedef SSLSignType SSL3SignType; #define hmac_md5 ssl_hmac_md5 #define hmac_sha ssl_hmac_sha #define hmac_sha256 ssl_hmac_sha256 +#define mac_aead ssl_mac_aead #define SET_ERROR_CODE /* reminder */ #define SEND_ALERT /* reminder */ @@ -290,9 +293,9 @@ typedef struct { } ssl3CipherSuiteCfg; #ifdef NSS_ENABLE_ECC -#define ssl_V3_SUITES_IMPLEMENTED 57 +#define ssl_V3_SUITES_IMPLEMENTED 63 #else -#define ssl_V3_SUITES_IMPLEMENTED 35 +#define ssl_V3_SUITES_IMPLEMENTED 37 #endif /* NSS_ENABLE_ECC */ #define MAX_DTLS_SRTP_CIPHER_SUITES 4 @@ -440,20 +443,6 @@ struct sslGatherStr { #define GS_DATA 3 #define GS_PAD 4 -typedef SECStatus (*SSLCipher)(void * context, - unsigned char * out, - int * outlen, - int maxout, - const unsigned char *in, - int inlen); -typedef SECStatus (*SSLCompressor)(void * context, - unsigned char * out, - int * outlen, - int maxout, - const unsigned char *in, - int inlen); -typedef SECStatus (*SSLDestroy)(void *context, PRBool freeit); - #if defined(NSS_PLATFORM_CLIENT_AUTH) && defined(XP_WIN32) typedef PCERT_KEY_CONTEXT PlatformKey; #elif defined(NSS_PLATFORM_CLIENT_AUTH) && defined(XP_MACOSX) @@ -485,11 +474,13 @@ typedef enum { cipher_camellia_128, cipher_camellia_256, cipher_seed, + cipher_aes_128_gcm, + cipher_chacha20, cipher_missing /* reserved for no such supported cipher */ /* This enum must match ssl3_cipherName[] in ssl3con.c. */ } SSL3BulkCipher; -typedef enum { type_stream, type_block } CipherType; +typedef enum { type_stream, type_block, type_aead } CipherType; #define MAX_IV_LENGTH 24 @@ -531,6 +522,30 @@ typedef struct { PRUint64 cipher_context[MAX_CIPHER_CONTEXT_LLONGS]; } ssl3KeyMaterial; +typedef SECStatus (*SSLCipher)(void * context, + unsigned char * out, + int * outlen, + int maxout, + const unsigned char *in, + int inlen); +typedef SECStatus (*SSLAEADCipher)( + ssl3KeyMaterial * keys, + PRBool doDecrypt, + unsigned char * out, + int * outlen, + int maxout, + const unsigned char *in, + int inlen, + const unsigned char *additionalData, + int additionalDataLen); +typedef SECStatus (*SSLCompressor)(void * context, + unsigned char * out, + int * outlen, + int maxout, + const unsigned char *in, + int inlen); +typedef SECStatus (*SSLDestroy)(void *context, PRBool freeit); + /* The DTLS anti-replay window. Defined here because we need it in * the cipher spec. Note that this is a ring buffer but left and * right represent the true window, with modular arithmetic used to @@ -557,6 +572,7 @@ typedef struct { int mac_size; SSLCipher encode; SSLCipher decode; + SSLAEADCipher aead; SSLDestroy destroy; void * encodeContext; void * decodeContext; @@ -706,8 +722,6 @@ typedef struct { PRBool tls_keygen; } ssl3KEADef; -typedef enum { kg_null, kg_strong, kg_export } SSL3KeyGenMode; - /* ** There are tables of these, all const. */ @@ -719,7 +733,8 @@ struct ssl3BulkCipherDefStr { CipherType type; int iv_size; int block_size; - SSL3KeyGenMode keygen_mode; + int tag_size; /* authentication tag size for AEAD ciphers. */ + int explicit_nonce_size; /* for AEAD ciphers. */ }; /* @@ -866,6 +881,8 @@ const ssl3CipherSuiteDef *suite_def; /* Shared state between ssl3_HandleFinished and ssl3_FinishHandshake */ PRBool cacheSID; + PRBool canFalseStart; /* Can/did we False Start */ + /* clientSigAndHash contains the contents of the signature_algorithms * extension (if any) from the client. This is only valid for TLS 1.2 * or later. */ @@ -1147,6 +1164,10 @@ struct sslSocketStr { unsigned long clientAuthRequested; unsigned long delayDisabled; /* Nagle delay disabled */ unsigned long firstHsDone; /* first handshake is complete. */ + unsigned long enoughFirstHsDone; /* enough of the first handshake is + * done for callbacks to be able to + * retrieve channel security + * parameters from the SSL socket. */ unsigned long handshakeBegun; unsigned long lastWriteBlocked; unsigned long recvdCloseNotify; /* received SSL EOF. */ @@ -1195,6 +1216,8 @@ const unsigned char * preferredCipher; void *badCertArg; SSLHandshakeCallback handshakeCallback; void *handshakeCallbackData; + SSLCanFalseStartCallback canFalseStartCallback; + void *canFalseStartCallbackData; void *pkcs11PinArg; SSLNextProtoCallback nextProtoCallback; void *nextProtoArg; @@ -1408,7 +1431,6 @@ extern void ssl3_SetAlwaysBlock(sslSocket *ss); extern SECStatus ssl_EnableNagleDelay(sslSocket *ss, PRBool enabled); -extern PRBool ssl3_CanFalseStart(sslSocket *ss); extern SECStatus ssl3_CompressMACEncryptRecord(ssl3CipherSpec * cwSpec, PRBool isServer, @@ -1830,9 +1852,7 @@ extern SECStatus ssl_InitSymWrapKeysLock(void); extern SECStatus ssl_FreeSymWrapKeysLock(void); -extern SECStatus ssl_InitSessionCacheLocks(PRBool lazyInit); - -extern SECStatus ssl_FreeSessionCacheLocks(void); +extern SECStatus ssl_InitSessionCacheLocks(void); /***************** platform client auth ****************/ diff --git a/chromium/net/third_party/nss/ssl/sslinfo.c b/chromium/net/third_party/nss/ssl/sslinfo.c index d29fb0d8f72..df7e669c50e 100644 --- a/chromium/net/third_party/nss/ssl/sslinfo.c +++ b/chromium/net/third_party/nss/ssl/sslinfo.c @@ -26,7 +26,6 @@ SSL_GetChannelInfo(PRFileDesc *fd, SSLChannelInfo *info, PRUintn len) sslSocket * ss; SSLChannelInfo inf; sslSessionID * sid; - PRBool enoughFirstHsDone = PR_FALSE; if (!info || len < sizeof inf.length) { PORT_SetError(SEC_ERROR_INVALID_ARGS); @@ -43,14 +42,7 @@ SSL_GetChannelInfo(PRFileDesc *fd, SSLChannelInfo *info, PRUintn len) memset(&inf, 0, sizeof inf); inf.length = PR_MIN(sizeof inf, len); - if (ss->firstHsDone) { - enoughFirstHsDone = PR_TRUE; - } else if (ss->version >= SSL_LIBRARY_VERSION_3_0 && - ssl3_CanFalseStart(ss)) { - enoughFirstHsDone = PR_TRUE; - } - - if (ss->opt.useSecurity && enoughFirstHsDone) { + if (ss->opt.useSecurity && ss->enoughFirstHsDone) { sid = ss->sec.ci.sid; inf.protocolVersion = ss->version; inf.authKeyBits = ss->sec.authKeyBits; @@ -109,7 +101,7 @@ SSL_GetChannelInfo(PRFileDesc *fd, SSLChannelInfo *info, PRUintn len) #define K_ECDHE "ECDHE", kt_ecdh #define C_SEED "SEED", calg_seed -#define C_CAMELLIA "CAMELLIA", calg_camellia +#define C_CAMELLIA "CAMELLIA", calg_camellia #define C_AES "AES", calg_aes #define C_RC4 "RC4", calg_rc4 #define C_RC2 "RC2", calg_rc2 @@ -117,6 +109,8 @@ SSL_GetChannelInfo(PRFileDesc *fd, SSLChannelInfo *info, PRUintn len) #define C_3DES "3DES", calg_3des #define C_NULL "NULL", calg_null #define C_SJ "SKIPJACK", calg_sj +#define C_AESGCM "AES-GCM", calg_aes_gcm +#define C_CHACHA20 "CHACHA20POLY1305", calg_chacha20 #define B_256 256, 256, 256 #define B_128 128, 128, 128 @@ -127,12 +121,16 @@ SSL_GetChannelInfo(PRFileDesc *fd, SSLChannelInfo *info, PRUintn len) #define B_40 128, 40, 40 #define B_0 0, 0, 0 +#define M_AEAD_128 "AEAD", ssl_mac_aead, 128 #define M_SHA256 "SHA256", ssl_hmac_sha256, 256 #define M_SHA "SHA1", ssl_mac_sha, 160 #define M_MD5 "MD5", ssl_mac_md5, 128 +#define M_NULL "NULL", ssl_mac_null, 0 static const SSLCipherSuiteInfo suiteInfo[] = { /* <------ Cipher suite --------------------> <auth> <KEA> <bulk cipher> <MAC> <FIPS> */ +{0,CS(TLS_RSA_WITH_AES_128_GCM_SHA256), S_RSA, K_RSA, C_AESGCM, B_128, M_AEAD_128, 1, 0, 0, }, + {0,CS(TLS_DHE_RSA_WITH_CAMELLIA_256_CBC_SHA), S_RSA, K_DHE, C_CAMELLIA, B_256, M_SHA, 0, 0, 0, }, {0,CS(TLS_DHE_DSS_WITH_CAMELLIA_256_CBC_SHA), S_DSA, K_DHE, C_CAMELLIA, B_256, M_SHA, 0, 0, 0, }, {0,CS(TLS_DHE_RSA_WITH_AES_256_CBC_SHA256), S_RSA, K_DHE, C_AES, B_256, M_SHA256, 1, 0, 0, }, @@ -146,6 +144,7 @@ static const SSLCipherSuiteInfo suiteInfo[] = { {0,CS(TLS_DHE_DSS_WITH_CAMELLIA_128_CBC_SHA), S_DSA, K_DHE, C_CAMELLIA, B_128, M_SHA, 0, 0, 0, }, {0,CS(TLS_DHE_DSS_WITH_RC4_128_SHA), S_DSA, K_DHE, C_RC4, B_128, M_SHA, 0, 0, 0, }, {0,CS(TLS_DHE_RSA_WITH_AES_128_CBC_SHA256), S_RSA, K_DHE, C_AES, B_128, M_SHA256, 1, 0, 0, }, +{0,CS(TLS_DHE_RSA_WITH_AES_128_GCM_SHA256), S_RSA, K_DHE, C_AESGCM, B_128, M_AEAD_128, 1, 0, 0, }, {0,CS(TLS_DHE_RSA_WITH_AES_128_CBC_SHA), S_RSA, K_DHE, C_AES, B_128, M_SHA, 1, 0, 0, }, {0,CS(TLS_DHE_DSS_WITH_AES_128_CBC_SHA), S_DSA, K_DHE, C_AES, B_128, M_SHA, 1, 0, 0, }, {0,CS(TLS_RSA_WITH_SEED_CBC_SHA), S_RSA, K_RSA, C_SEED,B_128, M_SHA, 1, 0, 0, }, @@ -175,6 +174,9 @@ static const SSLCipherSuiteInfo suiteInfo[] = { #ifdef NSS_ENABLE_ECC /* ECC cipher suites */ +{0,CS(TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256), S_RSA, K_ECDHE, C_AESGCM, B_128, M_AEAD_128, 1, 0, 0, }, +{0,CS(TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256), S_ECDSA, K_ECDHE, C_AESGCM, B_128, M_AEAD_128, 1, 0, 0, }, + {0,CS(TLS_ECDH_ECDSA_WITH_NULL_SHA), S_ECDSA, K_ECDH, C_NULL, B_0, M_SHA, 0, 0, 0, }, {0,CS(TLS_ECDH_ECDSA_WITH_RC4_128_SHA), S_ECDSA, K_ECDH, C_RC4, B_128, M_SHA, 0, 0, 0, }, {0,CS(TLS_ECDH_ECDSA_WITH_3DES_EDE_CBC_SHA), S_ECDSA, K_ECDH, C_3DES, B_3DES, M_SHA, 1, 0, 0, }, @@ -187,12 +189,14 @@ static const SSLCipherSuiteInfo suiteInfo[] = { {0,CS(TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA), S_ECDSA, K_ECDHE, C_AES, B_128, M_SHA, 1, 0, 0, }, {0,CS(TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256), S_ECDSA, K_ECDHE, C_AES, B_128, M_SHA256, 1, 0, 0, }, {0,CS(TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA), S_ECDSA, K_ECDHE, C_AES, B_256, M_SHA, 1, 0, 0, }, +{0,CS(TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305),S_ECDSA,K_ECDHE,C_CHACHA20,B_256,M_AEAD_128,0, 0, 0, }, {0,CS(TLS_ECDH_RSA_WITH_NULL_SHA), S_RSA, K_ECDH, C_NULL, B_0, M_SHA, 0, 0, 0, }, {0,CS(TLS_ECDH_RSA_WITH_RC4_128_SHA), S_RSA, K_ECDH, C_RC4, B_128, M_SHA, 0, 0, 0, }, {0,CS(TLS_ECDH_RSA_WITH_3DES_EDE_CBC_SHA), S_RSA, K_ECDH, C_3DES, B_3DES, M_SHA, 1, 0, 0, }, {0,CS(TLS_ECDH_RSA_WITH_AES_128_CBC_SHA), S_RSA, K_ECDH, C_AES, B_128, M_SHA, 1, 0, 0, }, {0,CS(TLS_ECDH_RSA_WITH_AES_256_CBC_SHA), S_RSA, K_ECDH, C_AES, B_256, M_SHA, 1, 0, 0, }, +{0,CS(TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305), S_RSA,K_ECDHE,C_CHACHA20,B_256,M_AEAD_128, 0, 0, 0, }, {0,CS(TLS_ECDHE_RSA_WITH_NULL_SHA), S_RSA, K_ECDHE, C_NULL, B_0, M_SHA, 0, 0, 0, }, {0,CS(TLS_ECDHE_RSA_WITH_RC4_128_SHA), S_RSA, K_ECDHE, C_RC4, B_128, M_SHA, 0, 0, 0, }, diff --git a/chromium/net/third_party/nss/ssl/sslnonce.c b/chromium/net/third_party/nss/ssl/sslnonce.c index 5d8a95407aa..a6f734948a3 100644 --- a/chromium/net/third_party/nss/ssl/sslnonce.c +++ b/chromium/net/third_party/nss/ssl/sslnonce.c @@ -35,91 +35,55 @@ static PZLock * cacheLock = NULL; #define LOCK_CACHE lock_cache() #define UNLOCK_CACHE PZ_Unlock(cacheLock) -static SECStatus -ssl_InitClientSessionCacheLock(void) -{ - cacheLock = PZ_NewLock(nssILockCache); - return cacheLock ? SECSuccess : SECFailure; -} - -static SECStatus -ssl_FreeClientSessionCacheLock(void) -{ - if (cacheLock) { - PZ_DestroyLock(cacheLock); - cacheLock = NULL; - return SECSuccess; - } - PORT_SetError(SEC_ERROR_NOT_INITIALIZED); - return SECFailure; -} - -static PRBool LocksInitializedEarly = PR_FALSE; - -static SECStatus -FreeSessionCacheLocks() -{ - SECStatus rv1, rv2; - rv1 = ssl_FreeSymWrapKeysLock(); - rv2 = ssl_FreeClientSessionCacheLock(); - if ( (SECSuccess == rv1) && (SECSuccess == rv2) ) { - return SECSuccess; - } - return SECFailure; -} +static PRCallOnceType lockOnce; +/* FreeSessionCacheLocks is a callback from NSS_RegisterShutdown which destroys + * the session cache locks on shutdown and resets them to their initial + * state. */ static SECStatus -InitSessionCacheLocks(void) +FreeSessionCacheLocks(void* appData, void* nssData) { - SECStatus rv1, rv2; - PRErrorCode rc; - rv1 = ssl_InitSymWrapKeysLock(); - rv2 = ssl_InitClientSessionCacheLock(); - if ( (SECSuccess == rv1) && (SECSuccess == rv2) ) { - return SECSuccess; - } - rc = PORT_GetError(); - FreeSessionCacheLocks(); - PORT_SetError(rc); - return SECFailure; -} + static const PRCallOnceType pristineCallOnce; + SECStatus rv; -/* free the session cache locks if they were initialized early */ -SECStatus -ssl_FreeSessionCacheLocks() -{ - PORT_Assert(PR_TRUE == LocksInitializedEarly); - if (!LocksInitializedEarly) { + if (!cacheLock) { PORT_SetError(SEC_ERROR_NOT_INITIALIZED); return SECFailure; } - FreeSessionCacheLocks(); - LocksInitializedEarly = PR_FALSE; - return SECSuccess; -} -static PRCallOnceType lockOnce; + PZ_DestroyLock(cacheLock); + cacheLock = NULL; -/* free the session cache locks if they were initialized lazily */ -static SECStatus ssl_ShutdownLocks(void* appData, void* nssData) -{ - PORT_Assert(PR_FALSE == LocksInitializedEarly); - if (LocksInitializedEarly) { - PORT_SetError(SEC_ERROR_LIBRARY_FAILURE); - return SECFailure; + rv = ssl_FreeSymWrapKeysLock(); + if (rv != SECSuccess) { + return rv; } - FreeSessionCacheLocks(); - memset(&lockOnce, 0, sizeof(lockOnce)); + + lockOnce = pristineCallOnce; return SECSuccess; } -static PRStatus initSessionCacheLocksLazily(void) +/* InitSessionCacheLocks is called, protected by lockOnce, to create the + * session cache locks. */ +static PRStatus +InitSessionCacheLocks(void) { - SECStatus rv = InitSessionCacheLocks(); - if (SECSuccess != rv) { + SECStatus rv; + + cacheLock = PZ_NewLock(nssILockCache); + if (cacheLock == NULL) { + return PR_FAILURE; + } + rv = ssl_InitSymWrapKeysLock(); + if (rv != SECSuccess) { + PRErrorCode error = PORT_GetError(); + PZ_DestroyLock(cacheLock); + cacheLock = NULL; + PORT_SetError(error); return PR_FAILURE; } - rv = NSS_RegisterShutdown(ssl_ShutdownLocks, NULL); + + rv = NSS_RegisterShutdown(FreeSessionCacheLocks, NULL); PORT_Assert(SECSuccess == rv); if (SECSuccess != rv) { return PR_FAILURE; @@ -127,34 +91,18 @@ static PRStatus initSessionCacheLocksLazily(void) return PR_SUCCESS; } -/* lazyInit means that the call is not happening during a 1-time - * initialization function, but rather during dynamic, lazy initialization - */ SECStatus -ssl_InitSessionCacheLocks(PRBool lazyInit) +ssl_InitSessionCacheLocks(void) { - if (LocksInitializedEarly) { - return SECSuccess; - } - - if (lazyInit) { - return (PR_SUCCESS == - PR_CallOnce(&lockOnce, initSessionCacheLocksLazily)) ? - SECSuccess : SECFailure; - } - - if (SECSuccess == InitSessionCacheLocks()) { - LocksInitializedEarly = PR_TRUE; - return SECSuccess; - } - - return SECFailure; + return (PR_SUCCESS == + PR_CallOnce(&lockOnce, InitSessionCacheLocks)) ? + SECSuccess : SECFailure; } -static void +static void lock_cache(void) { - ssl_InitSessionCacheLocks(PR_TRUE); + ssl_InitSessionCacheLocks(); PZ_Lock(cacheLock); } diff --git a/chromium/net/third_party/nss/ssl/sslproto.h b/chromium/net/third_party/nss/ssl/sslproto.h index b037887acf9..6b60a28616f 100644 --- a/chromium/net/third_party/nss/ssl/sslproto.h +++ b/chromium/net/third_party/nss/ssl/sslproto.h @@ -162,6 +162,10 @@ #define TLS_RSA_WITH_SEED_CBC_SHA 0x0096 +#define TLS_RSA_WITH_AES_128_GCM_SHA256 0x009C +#define TLS_DHE_RSA_WITH_AES_128_GCM_SHA256 0x009E +#define TLS_DHE_DSS_WITH_AES_128_GCM_SHA256 0x00A2 + /* TLS "Signaling Cipher Suite Value" (SCSV). May be requested by client. * Must NEVER be chosen by server. SSL 3.0 server acknowledges by sending * back an empty Renegotiation Info (RI) server hello extension. @@ -204,6 +208,14 @@ #define TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256 0xC023 #define TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256 0xC027 +#define TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 0xC02B +#define TLS_ECDH_ECDSA_WITH_AES_128_GCM_SHA256 0xC02D +#define TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 0xC02F +#define TLS_ECDH_RSA_WITH_AES_128_GCM_SHA256 0xC031 + +#define TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305 0xCC13 +#define TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305 0xCC14 + /* Netscape "experimental" cipher suites. */ #define SSL_RSA_OLDFIPS_WITH_3DES_EDE_CBC_SHA 0xffe0 #define SSL_RSA_OLDFIPS_WITH_DES_CBC_SHA 0xffe1 diff --git a/chromium/net/third_party/nss/ssl/sslsecur.c b/chromium/net/third_party/nss/ssl/sslsecur.c index 0714a0b75bb..6c7532e2552 100644 --- a/chromium/net/third_party/nss/ssl/sslsecur.c +++ b/chromium/net/third_party/nss/ssl/sslsecur.c @@ -99,21 +99,12 @@ ssl_Do1stHandshake(sslSocket *ss) if (ss->handshake == 0) { ssl_GetRecvBufLock(ss); ss->gs.recordLen = 0; + ss->gs.writeOffset = 0; + ss->gs.readOffset = 0; ssl_ReleaseRecvBufLock(ss); SSL_TRC(3, ("%d: SSL[%d]: handshake is completed", SSL_GETPID(), ss->fd)); - /* call handshake callback for ssl v2 */ - /* for v3 this is done in ssl3_HandleFinished() */ - if ((ss->handshakeCallback != NULL) && /* has callback */ - (!ss->firstHsDone) && /* only first time */ - (ss->version < SSL_LIBRARY_VERSION_3_0)) { /* not ssl3 */ - ss->firstHsDone = PR_TRUE; - (ss->handshakeCallback)(ss->fd, ss->handshakeCallbackData); - } - ss->firstHsDone = PR_TRUE; - ss->gs.writeOffset = 0; - ss->gs.readOffset = 0; break; } rv = (*ss->handshake)(ss); @@ -206,6 +197,7 @@ SSL_ResetHandshake(PRFileDesc *s, PRBool asServer) ssl_Get1stHandshakeLock(ss); ss->firstHsDone = PR_FALSE; + ss->enoughFirstHsDone = PR_FALSE; if ( asServer ) { ss->handshake = ssl2_BeginServerHandshake; ss->handshaking = sslHandshakingAsServer; @@ -221,6 +213,8 @@ SSL_ResetHandshake(PRFileDesc *s, PRBool asServer) ssl_ReleaseRecvBufLock(ss); ssl_GetSSL3HandshakeLock(ss); + ss->ssl3.hs.canFalseStart = PR_FALSE; + ss->ssl3.hs.restartTarget = NULL; /* ** Blow away old security state and get a fresh setup. @@ -266,7 +260,7 @@ SSL_ReHandshake(PRFileDesc *fd, PRBool flushCache) /* SSL v2 protocol does not support subsequent handshakes. */ if (ss->version < SSL_LIBRARY_VERSION_3_0) { - PORT_SetError(SEC_ERROR_INVALID_ARGS); + PORT_SetError(SSL_ERROR_FEATURE_NOT_SUPPORTED_FOR_SSL2); rv = SECFailure; } else { ssl_GetSSL3HandshakeLock(ss); @@ -331,6 +325,75 @@ SSL_HandshakeCallback(PRFileDesc *fd, SSLHandshakeCallback cb, return SECSuccess; } +/* Register an application callback to be called when false start may happen. +** Acquires and releases HandshakeLock. +*/ +SECStatus +SSL_SetCanFalseStartCallback(PRFileDesc *fd, SSLCanFalseStartCallback cb, + void *client_data) +{ + sslSocket *ss; + + ss = ssl_FindSocket(fd); + if (!ss) { + SSL_DBG(("%d: SSL[%d]: bad socket in SSL_SetCanFalseStartCallback", + SSL_GETPID(), fd)); + return SECFailure; + } + + if (!ss->opt.useSecurity) { + PORT_SetError(SEC_ERROR_INVALID_ARGS); + return SECFailure; + } + + ssl_Get1stHandshakeLock(ss); + ssl_GetSSL3HandshakeLock(ss); + + ss->canFalseStartCallback = cb; + ss->canFalseStartCallbackData = client_data; + + ssl_ReleaseSSL3HandshakeLock(ss); + ssl_Release1stHandshakeLock(ss); + + return SECSuccess; +} + +/* A utility function that can be called from a custom SSLCanFalseStartCallback +** function to determine what NSS would have done for this connection if the +** custom callback was not implemented. +*/ +SECStatus +SSL_DefaultCanFalseStart(PRFileDesc *fd, PRBool *canFalseStart) +{ + sslSocket *ss; + + *canFalseStart = PR_FALSE; + ss = ssl_FindSocket(fd); + if (!ss) { + SSL_DBG(("%d: SSL[%d]: bad socket in SSL_DefaultCanFalseStart", + SSL_GETPID(), fd)); + return SECFailure; + } + + if (!ss->ssl3.initialized) { + PORT_SetError(SEC_ERROR_INVALID_ARGS); + return SECFailure; + } + + if (ss->version < SSL_LIBRARY_VERSION_3_0) { + PORT_SetError(SSL_ERROR_FEATURE_NOT_SUPPORTED_FOR_SSL2); + return SECFailure; + } + + /* Require a forward-secret key exchange. */ + *canFalseStart = ss->ssl3.hs.kea_def->kea == kea_dhe_dss || + ss->ssl3.hs.kea_def->kea == kea_dhe_rsa || + ss->ssl3.hs.kea_def->kea == kea_ecdhe_ecdsa || + ss->ssl3.hs.kea_def->kea == kea_ecdhe_rsa; + + return SECSuccess; +} + /* Try to make progress on an SSL handshake by attempting to read the ** next handshake from the peer, and sending any responses. ** For non-blocking sockets, returns PR_ERROR_WOULD_BLOCK if it cannot @@ -1195,12 +1258,7 @@ ssl_SecureSend(sslSocket *ss, const unsigned char *buf, int len, int flags) ssl_Get1stHandshakeLock(ss); if (ss->version >= SSL_LIBRARY_VERSION_3_0) { ssl_GetSSL3HandshakeLock(ss); - if ((ss->ssl3.hs.ws == wait_change_cipher || - ss->ssl3.hs.ws == wait_finished || - ss->ssl3.hs.ws == wait_new_session_ticket) && - ssl3_CanFalseStart(ss)) { - canFalseStart = PR_TRUE; - } + canFalseStart = ss->ssl3.hs.canFalseStart; ssl_ReleaseSSL3HandshakeLock(ss); } if (!canFalseStart && diff --git a/chromium/net/third_party/nss/ssl/sslsnce.c b/chromium/net/third_party/nss/ssl/sslsnce.c index b0446adc17f..34e07b00511 100644 --- a/chromium/net/third_party/nss/ssl/sslsnce.c +++ b/chromium/net/third_party/nss/ssl/sslsnce.c @@ -1353,7 +1353,7 @@ SSL_ConfigServerSessionIDCache( int maxCacheEntries, PRUint32 ssl3_timeout, const char * directory) { - ssl_InitSessionCacheLocks(PR_FALSE); + ssl_InitSessionCacheLocks(); return SSL_ConfigServerSessionIDCacheInstance(&globalCache, maxCacheEntries, ssl2_timeout, ssl3_timeout, directory, PR_FALSE); } @@ -1467,7 +1467,7 @@ SSL_ConfigServerSessionIDCacheWithOpt( PRBool enableMPCache) { if (!enableMPCache) { - ssl_InitSessionCacheLocks(PR_FALSE); + ssl_InitSessionCacheLocks(); return ssl_ConfigServerSessionIDCacheInstanceWithOpt(&globalCache, ssl2_timeout, ssl3_timeout, directory, PR_FALSE, maxCacheEntries, maxCertCacheEntries, maxSrvNameCacheEntries); @@ -1512,7 +1512,7 @@ SSL_InheritMPServerSIDCacheInstance(cacheDesc *cache, const char * envString) return SECSuccess; /* already done. */ } - ssl_InitSessionCacheLocks(PR_FALSE); + ssl_InitSessionCacheLocks(); ssl_sid_lookup = ServerSessionIDLookup; ssl_sid_cache = ServerSessionIDCache; diff --git a/chromium/net/third_party/nss/ssl/sslsock.c b/chromium/net/third_party/nss/ssl/sslsock.c index db0da5f13d1..072fad5ba0b 100644 --- a/chromium/net/third_party/nss/ssl/sslsock.c +++ b/chromium/net/third_party/nss/ssl/sslsock.c @@ -67,8 +67,10 @@ static cipherPolicy ssl_ciphers[] = { /* Export France */ { TLS_DHE_DSS_WITH_AES_128_CBC_SHA, SSL_NOT_ALLOWED, SSL_NOT_ALLOWED }, { TLS_DHE_RSA_WITH_AES_128_CBC_SHA, SSL_NOT_ALLOWED, SSL_NOT_ALLOWED }, { TLS_DHE_RSA_WITH_AES_128_CBC_SHA256, SSL_NOT_ALLOWED, SSL_NOT_ALLOWED }, + { TLS_DHE_RSA_WITH_AES_128_GCM_SHA256, SSL_NOT_ALLOWED, SSL_NOT_ALLOWED }, { TLS_RSA_WITH_AES_128_CBC_SHA, SSL_NOT_ALLOWED, SSL_NOT_ALLOWED }, { TLS_RSA_WITH_AES_128_CBC_SHA256, SSL_NOT_ALLOWED, SSL_NOT_ALLOWED }, + { TLS_RSA_WITH_AES_128_GCM_SHA256, SSL_NOT_ALLOWED, SSL_NOT_ALLOWED }, { TLS_DHE_DSS_WITH_AES_256_CBC_SHA, SSL_NOT_ALLOWED, SSL_NOT_ALLOWED }, { TLS_DHE_RSA_WITH_AES_256_CBC_SHA, SSL_NOT_ALLOWED, SSL_NOT_ALLOWED }, { TLS_DHE_RSA_WITH_AES_256_CBC_SHA256, SSL_NOT_ALLOWED, SSL_NOT_ALLOWED }, @@ -94,7 +96,9 @@ static cipherPolicy ssl_ciphers[] = { /* Export France */ { TLS_ECDHE_ECDSA_WITH_3DES_EDE_CBC_SHA, SSL_NOT_ALLOWED, SSL_NOT_ALLOWED }, { TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, SSL_NOT_ALLOWED, SSL_NOT_ALLOWED }, { TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256,SSL_NOT_ALLOWED, SSL_NOT_ALLOWED }, + { TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,SSL_NOT_ALLOWED, SSL_NOT_ALLOWED }, { TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, SSL_NOT_ALLOWED, SSL_NOT_ALLOWED }, + { TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, SSL_NOT_ALLOWED, SSL_NOT_ALLOWED }, { TLS_ECDH_RSA_WITH_NULL_SHA, SSL_ALLOWED, SSL_ALLOWED }, { TLS_ECDH_RSA_WITH_RC4_128_SHA, SSL_NOT_ALLOWED, SSL_NOT_ALLOWED }, { TLS_ECDH_RSA_WITH_3DES_EDE_CBC_SHA, SSL_NOT_ALLOWED, SSL_NOT_ALLOWED }, @@ -105,7 +109,9 @@ static cipherPolicy ssl_ciphers[] = { /* Export France */ { TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, SSL_NOT_ALLOWED, SSL_NOT_ALLOWED }, { TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, SSL_NOT_ALLOWED, SSL_NOT_ALLOWED }, { TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, SSL_NOT_ALLOWED, SSL_NOT_ALLOWED }, + { TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, SSL_NOT_ALLOWED, SSL_NOT_ALLOWED }, { TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, SSL_NOT_ALLOWED, SSL_NOT_ALLOWED }, + { TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, SSL_NOT_ALLOWED, SSL_NOT_ALLOWED }, #endif /* NSS_ENABLE_ECC */ { 0, SSL_NOT_ALLOWED, SSL_NOT_ALLOWED } }; @@ -2451,10 +2457,14 @@ ssl_Poll(PRFileDesc *fd, PRInt16 how_flags, PRInt16 *p_out_flags) } else if (new_flags & PR_POLL_WRITE) { /* The caller is trying to write, but the handshake is ** blocked waiting for data to read, and the first - ** handshake has been sent. so do NOT to poll on write. + ** handshake has been sent. So do NOT to poll on write + ** unless we did false start. */ - new_flags ^= PR_POLL_WRITE; /* don't select on write. */ - new_flags |= PR_POLL_READ; /* do select on read. */ + if (!(ss->version >= SSL_LIBRARY_VERSION_3_0 && + ss->ssl3.hs.canFalseStart)) { + new_flags ^= PR_POLL_WRITE; /* don't select on write. */ + } + new_flags |= PR_POLL_READ; /* do select on read. */ } } } else if ((new_flags & PR_POLL_READ) && (SSL_DataPending(fd) > 0)) { diff --git a/chromium/net/third_party/nss/ssl/sslt.h b/chromium/net/third_party/nss/ssl/sslt.h index 41d01130d9e..a8007d8b4cf 100644 --- a/chromium/net/third_party/nss/ssl/sslt.h +++ b/chromium/net/third_party/nss/ssl/sslt.h @@ -91,9 +91,11 @@ typedef enum { ssl_calg_3des = 4, ssl_calg_idea = 5, ssl_calg_fortezza = 6, /* deprecated, now unused */ - ssl_calg_aes = 7, /* coming soon */ + ssl_calg_aes = 7, ssl_calg_camellia = 8, - ssl_calg_seed = 9 + ssl_calg_seed = 9, + ssl_calg_aes_gcm = 10, + ssl_calg_chacha20 = 11 } SSLCipherAlgorithm; typedef enum { @@ -102,7 +104,8 @@ typedef enum { ssl_mac_sha = 2, ssl_hmac_md5 = 3, /* TLS HMAC version of mac_md5 */ ssl_hmac_sha = 4, /* TLS HMAC version of mac_sha */ - ssl_hmac_sha256 = 5 + ssl_hmac_sha256 = 5, + ssl_mac_aead = 6 } SSLMACAlgorithm; typedef enum { @@ -158,6 +161,9 @@ typedef struct SSLCipherSuiteInfoStr { PRUint16 effectiveKeyBits; /* MAC info */ + /* AEAD ciphers don't have a MAC. For an AEAD cipher, macAlgorithmName + * is "AEAD", macAlgorithm is ssl_mac_aead, and macBits is the length in + * bits of the authentication tag. */ const char * macAlgorithmName; SSLMACAlgorithm macAlgorithm; PRUint16 macBits; diff --git a/chromium/net/tools/crl_set_dump/crl_set_dump.cc b/chromium/net/tools/crl_set_dump/crl_set_dump.cc index 191840803b6..6c3bb54bab5 100644 --- a/chromium/net/tools/crl_set_dump/crl_set_dump.cc +++ b/chromium/net/tools/crl_set_dump/crl_set_dump.cc @@ -38,10 +38,10 @@ int main(int argc, char** argv) { output_filename = base::FilePath::FromUTF8Unsafe(argv[3]); std::string crl_set_bytes, delta_bytes; - if (!file_util::ReadFileToString(crl_set_filename, &crl_set_bytes)) + if (!base::ReadFileToString(crl_set_filename, &crl_set_bytes)) return 1; if (!delta_filename.empty() && - !file_util::ReadFileToString(delta_filename, &delta_bytes)) { + !base::ReadFileToString(delta_filename, &delta_bytes)) { return 1; } diff --git a/chromium/net/tools/disk_cache_memory_test/disk_cache_memory_test.cc b/chromium/net/tools/disk_cache_memory_test/disk_cache_memory_test.cc new file mode 100644 index 00000000000..6bb14c221ef --- /dev/null +++ b/chromium/net/tools/disk_cache_memory_test/disk_cache_memory_test.cc @@ -0,0 +1,292 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include <cstdlib> +#include <fstream> +#include <iostream> +#include <string> +#include <vector> + +#include "base/at_exit.h" +#include "base/bind.h" +#include "base/callback.h" +#include "base/command_line.h" +#include "base/files/file_path.h" +#include "base/logging.h" +#include "base/memory/scoped_ptr.h" +#include "base/memory/scoped_vector.h" +#include "base/message_loop/message_loop.h" +#include "base/message_loop/message_loop_proxy.h" +#include "base/run_loop.h" +#include "base/strings/string_number_conversions.h" +#include "base/strings/string_piece.h" +#include "base/strings/string_split.h" +#include "base/strings/stringprintf.h" +#include "net/base/cache_type.h" +#include "net/base/net_errors.h" +#include "net/disk_cache/disk_cache.h" +#include "net/disk_cache/simple/simple_backend_impl.h" +#include "net/disk_cache/simple/simple_index.h" + +namespace disk_cache { +namespace { + +const char kBlockFileBackendType[] = "block_file"; +const char kSimpleBackendType[] = "simple"; + +const char kDiskCacheType[] = "disk_cache"; +const char kAppCacheType[] = "app_cache"; + +const char kPrivateDirty[] = "Private_Dirty:"; +const char kReadWrite[] = "rw-"; +const char kHeap[] = "[heap]"; +const char kKb[] = "kB"; + +struct CacheSpec { + public: + static scoped_ptr<CacheSpec> Parse(const std::string& spec_string) { + std::vector<std::string> tokens; + base::SplitString(spec_string, ':', &tokens); + if (tokens.size() != 3) + return scoped_ptr<CacheSpec>(); + if (tokens[0] != kBlockFileBackendType && tokens[0] != kSimpleBackendType) + return scoped_ptr<CacheSpec>(); + if (tokens[1] != kDiskCacheType && tokens[1] != kAppCacheType) + return scoped_ptr<CacheSpec>(); + return scoped_ptr<CacheSpec>(new CacheSpec( + tokens[0] == kBlockFileBackendType ? net::CACHE_BACKEND_BLOCKFILE + : net::CACHE_BACKEND_SIMPLE, + tokens[1] == kDiskCacheType ? net::DISK_CACHE : net::APP_CACHE, + base::FilePath(tokens[2]))); + } + + const net::BackendType backend_type; + const net::CacheType cache_type; + const base::FilePath path; + + private: + CacheSpec(net::BackendType backend_type, + net::CacheType cache_type, + const base::FilePath& path) + : backend_type(backend_type), + cache_type(cache_type), + path(path) { + } +}; + +void SetSuccessCodeOnCompletion(base::RunLoop* run_loop, + bool* succeeded, + int net_error) { + if (net_error == net::OK) { + *succeeded = true; + } else { + *succeeded = false; + } + run_loop->Quit(); +} + +scoped_ptr<Backend> CreateAndInitBackend(const CacheSpec& spec) { + scoped_ptr<Backend> result; + scoped_ptr<Backend> backend; + bool succeeded = false; + base::RunLoop run_loop; + const net::CompletionCallback callback = base::Bind( + &SetSuccessCodeOnCompletion, + base::Unretained(&run_loop), + base::Unretained(&succeeded)); + const int net_error = CreateCacheBackend( + spec.cache_type, spec.backend_type, spec.path, 0, false, + base::MessageLoopProxy::current(), NULL, &backend, callback); + if (net_error == net::OK) + callback.Run(net::OK); + else + run_loop.Run(); + if (!succeeded) { + LOG(ERROR) << "Could not initialize backend in " + << spec.path.LossyDisplayName(); + return result.Pass(); + } + // For the simple cache, the index may not be initialized yet. + if (spec.backend_type == net::CACHE_BACKEND_SIMPLE) { + base::RunLoop index_run_loop; + const net::CompletionCallback index_callback = base::Bind( + &SetSuccessCodeOnCompletion, + base::Unretained(&index_run_loop), + base::Unretained(&succeeded)); + SimpleBackendImpl* simple_backend = + static_cast<SimpleBackendImpl*>(backend.get()); + const int index_net_error = + simple_backend->index()->ExecuteWhenReady(index_callback); + if (index_net_error == net::OK) + index_callback.Run(net::OK); + else + index_run_loop.Run(); + if (!succeeded) { + LOG(ERROR) << "Could not initialize Simple Cache in " + << spec.path.LossyDisplayName(); + return result.Pass(); + } + } + DCHECK(backend); + result.swap(backend); + return result.Pass(); +} + +// Parses range lines from /proc/<PID>/smaps, e.g. (anonymous read write): +// 7f819d88b000-7f819d890000 rw-p 00000000 00:00 0 +bool ParseRangeLine(const std::string& line, + std::vector<std::string>* tokens, + bool* is_anonymous_read_write) { + tokens->clear(); + base::SplitStringAlongWhitespace(line, tokens); + if (tokens->size() == 5) { + const std::string& mode = (*tokens)[1]; + *is_anonymous_read_write = !mode.compare(0, 3, kReadWrite); + return true; + } + // On Android, most of the memory is allocated in the heap, instead of being + // mapped. + if (tokens->size() == 6) { + const std::string& type = (*tokens)[5]; + *is_anonymous_read_write = (type == kHeap); + return true; + } + return false; +} + +// Parses range property lines from /proc/<PID>/smaps, e.g.: +// Private_Dirty: 16 kB +bool ParseRangeProperty(const std::string& line, + std::vector<std::string>* tokens, + uint64* size, + bool* is_private_dirty) { + tokens->clear(); + base::SplitStringAlongWhitespace(line, tokens); + if (tokens->size() != 3) + return false; + const std::string& type = (*tokens)[0]; + if (type != kPrivateDirty) + return true; + const std::string& unit = (*tokens)[2]; + if (unit != kKb) { + LOG(WARNING) << "Discarding value not in kB: " << line; + return true; + } + const std::string& size_str = (*tokens)[1]; + uint64 map_size = 0; + if (!base::StringToUint64(size_str, &map_size)) + return false; + *is_private_dirty = true; + *size = map_size; + return true; +} + +uint64 GetMemoryConsumption() { + std::ifstream maps_file( + base::StringPrintf("/proc/%d/smaps", getpid()).c_str()); + if (!maps_file.good()) { + LOG(ERROR) << "Could not open smaps file."; + return false; + } + std::string line; + std::vector<std::string> tokens; + uint64 total_size = 0; + if (!std::getline(maps_file, line) || line.empty()) + return total_size; + while (true) { + bool is_anonymous_read_write = false; + if (!ParseRangeLine(line, &tokens, &is_anonymous_read_write)) { + LOG(WARNING) << "Parsing smaps - did not expect line: " << line; + } + if (!std::getline(maps_file, line) || line.empty()) + return total_size; + bool is_private_dirty = false; + uint64 size = 0; + while (ParseRangeProperty(line, &tokens, &size, &is_private_dirty)) { + if (is_anonymous_read_write && is_private_dirty) { + total_size += size; + is_private_dirty = false; + } + if (!std::getline(maps_file, line) || line.empty()) + return total_size; + } + } + return total_size; +} + +bool CacheMemTest(const ScopedVector<CacheSpec>& specs) { + ScopedVector<Backend> backends; + ScopedVector<CacheSpec>::const_iterator it; + for (it = specs.begin(); it != specs.end(); ++it) { + scoped_ptr<Backend> backend = CreateAndInitBackend(**it); + if (!backend) + return false; + std::cout << "Number of entries in " << (*it)->path.LossyDisplayName() + << " : " << backend->GetEntryCount() << std::endl; + backends.push_back(backend.release()); + } + const uint64 memory_consumption = GetMemoryConsumption(); + std::cout << "Private dirty memory: " << memory_consumption << " kB" + << std::endl; + return true; +} + +void PrintUsage(std::ostream* stream) { + *stream << "Usage: disk_cache_mem_test " + << "--spec-1=<spec> " + << "[--spec-2=<spec>]" + << std::endl + << " with <cache_spec>=<backend_type>:<cache_type>:<cache_path>" + << std::endl + << " <backend_type>='block_file'|'simple'" << std::endl + << " <cache_type>='disk_cache'|'app_cache'" << std::endl + << " <cache_path>=file system path" << std::endl; +} + +bool ParseAndStoreSpec(const std::string& spec_str, + ScopedVector<CacheSpec>* specs) { + scoped_ptr<CacheSpec> spec = CacheSpec::Parse(spec_str); + if (!spec) { + PrintUsage(&std::cerr); + return false; + } + specs->push_back(spec.release()); + return true; +} + +bool Main(int argc, char** argv) { + base::AtExitManager at_exit_manager; + base::MessageLoopForIO message_loop; + CommandLine::Init(argc, argv); + const CommandLine& command_line = *CommandLine::ForCurrentProcess(); + if (command_line.HasSwitch("help")) { + PrintUsage(&std::cout); + return true; + } + if ((command_line.GetSwitches().size() != 1 && + command_line.GetSwitches().size() != 2) || + !command_line.HasSwitch("spec-1") || + (command_line.GetSwitches().size() == 2 && + !command_line.HasSwitch("spec-2"))) { + PrintUsage(&std::cerr); + return false; + } + ScopedVector<CacheSpec> specs; + const std::string spec_str_1 = command_line.GetSwitchValueASCII("spec-1"); + if (!ParseAndStoreSpec(spec_str_1, &specs)) + return false; + if (command_line.HasSwitch("spec-2")) { + const std::string spec_str_2 = command_line.GetSwitchValueASCII("spec-2"); + if (!ParseAndStoreSpec(spec_str_2, &specs)) + return false; + } + return CacheMemTest(specs); +} + +} // namespace +} // namespace disk_cache + +int main(int argc, char** argv) { + return !disk_cache::Main(argc, argv); +} diff --git a/chromium/net/tools/dns_fuzz_stub/dns_fuzz_stub.cc b/chromium/net/tools/dns_fuzz_stub/dns_fuzz_stub.cc index bcdb7b72dce..f9caa79df62 100644 --- a/chromium/net/tools/dns_fuzz_stub/dns_fuzz_stub.cc +++ b/chromium/net/tools/dns_fuzz_stub/dns_fuzz_stub.cc @@ -53,7 +53,7 @@ bool ReadTestCase(const char* filename, base::FilePath filepath = base::FilePath::FromUTF8Unsafe(filename); std::string json; - if (!file_util::ReadFileToString(filepath, &json)) { + if (!base::ReadFileToString(filepath, &json)) { LOG(ERROR) << filename << ": couldn't read file."; return false; } diff --git a/chromium/net/tools/fetch/http_listen_socket.cc b/chromium/net/tools/fetch/http_listen_socket.cc index 10b601eeca2..410a0ba55e5 100644 --- a/chromium/net/tools/fetch/http_listen_socket.cc +++ b/chromium/net/tools/fetch/http_listen_socket.cc @@ -4,51 +4,50 @@ #include "net/tools/fetch/http_listen_socket.h" -#include <map> - #include "base/compiler_specific.h" #include "base/logging.h" #include "base/message_loop/message_loop.h" +#include "base/stl_util.h" #include "base/strings/string_number_conversions.h" #include "net/tools/fetch/http_server_request_info.h" #include "net/tools/fetch/http_server_response_info.h" -HttpListenSocket::HttpListenSocket(SocketDescriptor s, +HttpListenSocket::HttpListenSocket(net::SocketDescriptor s, HttpListenSocket::Delegate* delegate) : net::TCPListenSocket(s, this), delegate_(delegate) { } HttpListenSocket::~HttpListenSocket() { + STLDeleteElements(&connections_); } void HttpListenSocket::Accept() { - SocketDescriptor conn = net::TCPListenSocket::AcceptSocket(); - DCHECK_NE(conn, net::TCPListenSocket::kInvalidSocket); - if (conn == net::TCPListenSocket::kInvalidSocket) { + net::SocketDescriptor conn = net::TCPListenSocket::AcceptSocket(); + DCHECK_NE(conn, net::kInvalidSocket); + if (conn == net::kInvalidSocket) { // TODO } else { - scoped_refptr<HttpListenSocket> sock( + scoped_ptr<StreamListenSocket> sock( new HttpListenSocket(conn, delegate_)); - // It's up to the delegate to AddRef if it wants to keep it around. - DidAccept(this, sock.get()); + DidAccept(this, sock.Pass()); } } // static -scoped_refptr<HttpListenSocket> HttpListenSocket::CreateAndListen( +scoped_ptr<HttpListenSocket> HttpListenSocket::CreateAndListen( const std::string& ip, int port, HttpListenSocket::Delegate* delegate) { - SocketDescriptor s = net::TCPListenSocket::CreateAndBind(ip, port); - if (s == net::TCPListenSocket::kInvalidSocket) { + net::SocketDescriptor s = net::TCPListenSocket::CreateAndBind(ip, port); + if (s == net::kInvalidSocket) { // TODO (ibrar): error handling. } else { - scoped_refptr<HttpListenSocket> serv = new HttpListenSocket(s, delegate); + scoped_ptr<HttpListenSocket> serv(new HttpListenSocket(s, delegate)); serv->Listen(); - return serv; + return serv.Pass(); } - return NULL; + return scoped_ptr<HttpListenSocket>(); } // @@ -180,9 +179,10 @@ HttpServerRequestInfo* HttpListenSocket::ParseHeaders() { return NULL; } -void HttpListenSocket::DidAccept(net::StreamListenSocket* server, - net::StreamListenSocket* connection) { - connection->AddRef(); +void HttpListenSocket::DidAccept( + net::StreamListenSocket* server, + scoped_ptr<net::StreamListenSocket> connection) { + connections_.insert(connection.release()); } void HttpListenSocket::DidRead(net::StreamListenSocket* connection, @@ -199,7 +199,9 @@ void HttpListenSocket::DidRead(net::StreamListenSocket* connection, } void HttpListenSocket::DidClose(net::StreamListenSocket* sock) { - sock->Release(); + size_t count = connections_.erase(sock); + DCHECK_EQ(1u, count); + delete sock; } // Convert the numeric status code to a string. diff --git a/chromium/net/tools/fetch/http_listen_socket.h b/chromium/net/tools/fetch/http_listen_socket.h index 379f73cfda7..e0a58c03e2c 100644 --- a/chromium/net/tools/fetch/http_listen_socket.h +++ b/chromium/net/tools/fetch/http_listen_socket.h @@ -5,6 +5,8 @@ #ifndef NET_BASE_TOOLS_HTTP_LISTEN_SOCKET_H_ #define NET_BASE_TOOLS_HTTP_LISTEN_SOCKET_H_ +#include <set> + #include "base/message_loop/message_loop.h" #include "net/socket/stream_listen_socket.h" #include "net/socket/tcp_listen_socket.h" @@ -25,7 +27,9 @@ class HttpListenSocket : public net::TCPListenSocket, virtual ~Delegate() {} }; - static scoped_refptr<HttpListenSocket> CreateAndListen( + virtual ~HttpListenSocket(); + + static scoped_ptr<HttpListenSocket> CreateAndListen( const std::string& ip, int port, HttpListenSocket::Delegate* delegate); // Send a server response. @@ -33,8 +37,9 @@ class HttpListenSocket : public net::TCPListenSocket, void Respond(HttpServerResponseInfo* info, std::string& data); // StreamListenSocket::Delegate. - virtual void DidAccept(net::StreamListenSocket* server, - net::StreamListenSocket* connection) OVERRIDE; + virtual void DidAccept( + net::StreamListenSocket* server, + scoped_ptr<net::StreamListenSocket> connection) OVERRIDE; virtual void DidRead(net::StreamListenSocket* connection, const char* data, int len) OVERRIDE; virtual void DidClose(net::StreamListenSocket* sock) OVERRIDE; @@ -44,13 +49,10 @@ class HttpListenSocket : public net::TCPListenSocket, virtual void Accept() OVERRIDE; private: - friend class base::RefCountedThreadSafe<net::StreamListenSocket>; - static const int kReadBufSize = 16 * 1024; // Must run in the IO thread. - HttpListenSocket(SocketDescriptor s, HttpListenSocket::Delegate* del); - virtual ~HttpListenSocket(); + HttpListenSocket(net::SocketDescriptor s, HttpListenSocket::Delegate* del); // Expects the raw data to be stored in recv_data_. If parsing is successful, // will remove the data parsed from recv_data_, leaving only the unused @@ -60,6 +62,8 @@ class HttpListenSocket : public net::TCPListenSocket, HttpListenSocket::Delegate* const delegate_; std::string recv_data_; + std::set<StreamListenSocket*> connections_; + DISALLOW_COPY_AND_ASSIGN(HttpListenSocket); }; diff --git a/chromium/net/tools/fetch/http_session.h b/chromium/net/tools/fetch/http_session.h index 7d87e05c70b..b0266f2a9f7 100644 --- a/chromium/net/tools/fetch/http_session.h +++ b/chromium/net/tools/fetch/http_session.h @@ -19,7 +19,7 @@ class HttpSession : HttpListenSocket::Delegate { HttpServerRequestInfo* info) OVERRIDE; private: - scoped_refptr<HttpListenSocket> socket_; + scoped_ptr<HttpListenSocket> socket_; DISALLOW_COPY_AND_ASSIGN(HttpSession); }; diff --git a/chromium/net/tools/flip_server/create_listener.cc b/chromium/net/tools/flip_server/create_listener.cc index 4676912c117..21867a918b9 100644 --- a/chromium/net/tools/flip_server/create_listener.cc +++ b/chromium/net/tools/flip_server/create_listener.cc @@ -67,7 +67,7 @@ bool CloseSocket(int *fd, int tries) { //////////////////////////////////////////////////////////////////////////////// // Sets an FD to be nonblocking. -void SetNonBlocking(int fd) { +void FlipSetNonBlocking(int fd) { DCHECK_GE(fd, 0); int fcntl_return = fcntl(fd, F_GETFL, 0); @@ -270,7 +270,7 @@ int CreateConnectedSocket( int *connect_fd, return -1; } - SetNonBlocking( sock ); + FlipSetNonBlocking( sock ); if (disable_nagle) { if (!SetDisableNagle(sock)) { @@ -297,4 +297,3 @@ int CreateConnectedSocket( int *connect_fd, } } // namespace net - diff --git a/chromium/net/tools/flip_server/create_listener.h b/chromium/net/tools/flip_server/create_listener.h index 4a4a6e16cbd..a3f5a87e6a3 100644 --- a/chromium/net/tools/flip_server/create_listener.h +++ b/chromium/net/tools/flip_server/create_listener.h @@ -10,7 +10,7 @@ namespace net { -void SetNonBlocking(int fd); +void FlipSetNonBlocking(int fd); // Summary: // creates a socket for listening, and bind()s and listen()s it. @@ -54,4 +54,3 @@ int CreateConnectedSocket(int *connect_fd, } // namespace net #endif // NET_TOOLS_FLIP_SERVER_CREATE_LISTENER_H__ - diff --git a/chromium/net/tools/flip_server/epoll_server.cc b/chromium/net/tools/flip_server/epoll_server.cc index af09c143b29..0e09a6d7d03 100644 --- a/chromium/net/tools/flip_server/epoll_server.cc +++ b/chromium/net/tools/flip_server/epoll_server.cc @@ -4,11 +4,10 @@ #include "net/tools/flip_server/epoll_server.h" +#include <unistd.h> // For read, pipe, close and write. #include <stdlib.h> // for abort #include <errno.h> // for errno and strerror_r #include <algorithm> -#include <ostream> -#include <unistd.h> // For read, pipe, close and write. #include <utility> #include <vector> @@ -482,7 +481,7 @@ void EpollServer::UnregisterAlarm(const AlarmRegToken& iterator_token) { } int EpollServer::NumFDsRegistered() const { - DCHECK(cb_map_.size() >= 1); + DCHECK_GE(cb_map_.size(), 1u); // Omit the internal FD (read_fd_) return cb_map_.size() - 1; } @@ -490,7 +489,7 @@ int EpollServer::NumFDsRegistered() const { void EpollServer::Wake() { char data = 'd'; // 'd' is for data. It's good enough for me. int rv = write(write_fd_, &data, 1); - DCHECK(rv == 1); + DCHECK_EQ(rv, 1); } int64 EpollServer::NowInUsec() const { @@ -819,4 +818,3 @@ void EpollAlarm::UnregisterIfRegistered() { } } // namespace net - diff --git a/chromium/net/tools/flip_server/flip_config.cc b/chromium/net/tools/flip_server/flip_config.cc index eb5c3caff69..3de302856f1 100644 --- a/chromium/net/tools/flip_server/flip_config.cc +++ b/chromium/net/tools/flip_server/flip_config.cc @@ -72,7 +72,7 @@ FlipAcceptor::FlipAcceptor(enum FlipHandlerType flip_handler_type, } } - SetNonBlocking(listen_fd_); + FlipSetNonBlocking(listen_fd_); VLOG(1) << "Listening on socket: "; if (flip_handler_type == FLIP_HANDLER_PROXY) VLOG(1) << "\tType : Proxy"; diff --git a/chromium/net/tools/flip_server/flip_test_utils.cc b/chromium/net/tools/flip_server/flip_test_utils.cc new file mode 100644 index 00000000000..d9846cf5ac7 --- /dev/null +++ b/chromium/net/tools/flip_server/flip_test_utils.cc @@ -0,0 +1,15 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/tools/flip_server/flip_test_utils.h" + +namespace net { + +MockSMInterface::MockSMInterface() { +} + +MockSMInterface::~MockSMInterface() { +} + +} // namespace net diff --git a/chromium/net/tools/flip_server/flip_test_utils.h b/chromium/net/tools/flip_server/flip_test_utils.h new file mode 100644 index 00000000000..8a10b9563d1 --- /dev/null +++ b/chromium/net/tools/flip_server/flip_test_utils.h @@ -0,0 +1,53 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_TOOLS_FLIP_SERVER_FLIP_TEST_UTILS_H_ +#define NET_TOOLS_FLIP_SERVER_FLIP_TEST_UTILS_H_ + +#include <string> + +#include "net/tools/flip_server/sm_interface.h" +#include "testing/gmock/include/gmock/gmock.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace net { + +class MockSMInterface : public SMInterface { + public: + MockSMInterface(); + virtual ~MockSMInterface(); + + MOCK_METHOD2(InitSMInterface, void(SMInterface*, int32)); + MOCK_METHOD8(InitSMConnection, void(SMConnectionPoolInterface*, + SMInterface*, + EpollServer*, + int, + std::string, + std::string, + std::string, + bool)); + MOCK_METHOD2(ProcessReadInput, size_t(const char*, size_t)); + MOCK_METHOD2(ProcessWriteInput, size_t(const char*, size_t)); + MOCK_METHOD1(SetStreamID, void(uint32 stream_id)); + MOCK_CONST_METHOD0(MessageFullyRead, bool()); + MOCK_CONST_METHOD0(Error, bool()); + MOCK_CONST_METHOD0(ErrorAsString, const char*()); + MOCK_METHOD0(Reset, void()); + MOCK_METHOD1(ResetForNewInterface, void(int32 server_idx)); + MOCK_METHOD0(ResetForNewConnection, void()); + MOCK_METHOD0(Cleanup, void()); + MOCK_METHOD0(PostAcceptHook, int()); + MOCK_METHOD3(NewStream, void(uint32, uint32, const std::string&)); + MOCK_METHOD1(SendEOF, void(uint32 stream_id)); + MOCK_METHOD1(SendErrorNotFound, void(uint32 stream_id)); + MOCK_METHOD2(SendSynStream, size_t(uint32, const BalsaHeaders&)); + MOCK_METHOD2(SendSynReply, size_t(uint32, const BalsaHeaders&)); + MOCK_METHOD5(SendDataFrame, void(uint32, const char*, int64, uint32, bool)); + MOCK_METHOD0(GetOutput, void()); + MOCK_METHOD0(set_is_request, void()); +}; + +} // namespace net + +#endif // NET_TOOLS_FLIP_SERVER_FLIP_TEST_UTILS_H_ diff --git a/chromium/net/tools/flip_server/http_interface.cc b/chromium/net/tools/flip_server/http_interface.cc index 7a44c0313af..916ba51f6ab 100644 --- a/chromium/net/tools/flip_server/http_interface.cc +++ b/chromium/net/tools/flip_server/http_interface.cc @@ -14,11 +14,9 @@ namespace net { HttpSM::HttpSM(SMConnection* connection, SMInterface* sm_spdy_interface, - EpollServer* epoll_server, MemoryCache* memory_cache, FlipAcceptor* acceptor) - : seq_num_(0), - http_framer_(new BalsaFrame), + : http_framer_(new BalsaFrame), stream_id_(0), server_idx_(-1), connection_(connection), @@ -37,7 +35,7 @@ HttpSM::~HttpSM() { delete http_framer_; } -void HttpSM::ProcessBodyData(const char *input, size_t size) { +void HttpSM::ProcessBodyData(const char* input, size_t size) { if (acceptor_->flip_handler_type_ == FLIP_HANDLER_PROXY) { VLOG(2) << ACCEPTOR_CLIENT_IDENT << "HttpSM: Process Body Data: stream " << stream_id_ << ": size " << size; @@ -94,10 +92,6 @@ void HttpSM::AddToOutputOrder(const MemCacheIter& mci) { output_ordering_.AddToOutputOrder(mci); } -void HttpSM::SendOKResponse(uint32 stream_id, std::string* output) { - SendOKResponseImpl(stream_id, output); -} - void HttpSM::InitSMInterface(SMInterface* sm_spdy_interface, int32 server_idx) { sm_spdy_interface_ = sm_spdy_interface; @@ -133,10 +127,10 @@ size_t HttpSM::ProcessReadInput(const char* data, size_t len) { size_t HttpSM::ProcessWriteInput(const char* data, size_t len) { VLOG(2) << ACCEPTOR_CLIENT_IDENT << "HttpSM: Process write input: size " << len << ": stream " << stream_id_; - char * dataPtr = new char[len]; + char* dataPtr = new char[len]; memcpy(dataPtr, data, len); DataFrame* data_frame = new DataFrame; - data_frame->data = (const char *)dataPtr; + data_frame->data = dataPtr; data_frame->size = len; data_frame->delete_when_done = true; connection_->EnqueueDataFrame(data_frame); @@ -178,7 +172,6 @@ void HttpSM::ResetForNewConnection() { << "Sending EOF to spdy."; sm_spdy_interface_->SendEOF(stream_id_); } - seq_num_ = 0; output_ordering_.Reset(); http_framer_->Reset(); if (sm_spdy_interface_) { @@ -257,17 +250,6 @@ void HttpSM::SendErrorNotFoundImpl(uint32 stream_id) { output_ordering_.RemoveStreamId(stream_id); } -void HttpSM::SendOKResponseImpl(uint32 stream_id, std::string* output) { - BalsaHeaders my_headers; - my_headers.SetFirstlineFromStringPieces("HTTP/1.1", "200", "OK"); - my_headers.RemoveAllOfHeader("content-length"); - my_headers.AppendHeader("transfer-encoding", "chunked"); - SendSynReplyImpl(stream_id, my_headers); - SendDataFrame(stream_id, output->c_str(), output->size(), 0, false); - SendEOFImpl(stream_id); - output_ordering_.RemoveStreamId(stream_id); -} - size_t HttpSM::SendSynReplyImpl(uint32 stream_id, const BalsaHeaders& headers) { SimpleBuffer sb; headers.WriteHeaderAndEndingToBuffer(&sb); diff --git a/chromium/net/tools/flip_server/http_interface.h b/chromium/net/tools/flip_server/http_interface.h index e311881e342..18e616da205 100644 --- a/chromium/net/tools/flip_server/http_interface.h +++ b/chromium/net/tools/flip_server/http_interface.h @@ -2,8 +2,8 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#ifndef NET_TOOLS_FLIP_SERVER_HTTP_INTERFACE_ -#define NET_TOOLS_FLIP_SERVER_HTTP_INTERFACE_ +#ifndef NET_TOOLS_FLIP_SERVER_HTTP_INTERFACE_H_ +#define NET_TOOLS_FLIP_SERVER_HTTP_INTERFACE_H_ #include <string> @@ -27,7 +27,6 @@ class HttpSM : public BalsaVisitorInterface, public: HttpSM(SMConnection* connection, SMInterface* sm_spdy_interface, - EpollServer* epoll_server, MemoryCache* memory_cache, FlipAcceptor* acceptor); virtual ~HttpSM(); @@ -69,9 +68,9 @@ class HttpSM : public BalsaVisitorInterface, public: void AddToOutputOrder(const MemCacheIter& mci); - void SendOKResponse(uint32 stream_id, std::string* output); BalsaFrame* spdy_framer() { return http_framer_; } virtual void set_is_request() OVERRIDE {} + const OutputOrdering& output_ordering() const { return output_ordering_; } // SMInterface: virtual void InitSMInterface(SMInterface* sm_spdy_interface, @@ -110,7 +109,7 @@ class HttpSM : public BalsaVisitorInterface, private: void SendEOFImpl(uint32 stream_id); void SendErrorNotFoundImpl(uint32 stream_id); - void SendOKResponseImpl(uint32 stream_id, std::string* output); + void SendOKResponseImpl(uint32 stream_id, const std::string& output); size_t SendSynReplyImpl(uint32 stream_id, const BalsaHeaders& headers); size_t SendSynStreamImpl(uint32 stream_id, const BalsaHeaders& headers); void SendDataFrameImpl(uint32 stream_id, const char* data, int64 len, @@ -119,7 +118,6 @@ class HttpSM : public BalsaVisitorInterface, virtual void GetOutput() OVERRIDE; private: - uint64 seq_num_; BalsaFrame* http_framer_; BalsaHeaders headers_; uint32 stream_id_; @@ -133,7 +131,6 @@ class HttpSM : public BalsaVisitorInterface, FlipAcceptor* acceptor_; }; -} // namespace - -#endif // NET_TOOLS_FLIP_SERVER_HTTP_INTERFACE_ +} // namespace net +#endif // NET_TOOLS_FLIP_SERVER_HTTP_INTERFACE_H_ diff --git a/chromium/net/tools/flip_server/http_interface_test.cc b/chromium/net/tools/flip_server/http_interface_test.cc new file mode 100644 index 00000000000..ba9b3aa36bb --- /dev/null +++ b/chromium/net/tools/flip_server/http_interface_test.cc @@ -0,0 +1,487 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/tools/flip_server/http_interface.h" + +#include <list> + +#include "base/memory/scoped_ptr.h" +#include "base/stl_util.h" +#include "base/strings/string_piece.h" +#include "net/tools/flip_server/balsa_enums.h" +#include "net/tools/flip_server/balsa_frame.h" +#include "net/tools/flip_server/balsa_headers.h" +#include "net/tools/flip_server/flip_config.h" +#include "net/tools/flip_server/flip_test_utils.h" +#include "net/tools/flip_server/mem_cache.h" +#include "testing/gmock/include/gmock/gmock.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace net { + +using ::base::StringPiece; +using ::testing::_; +using ::testing::InSequence; + +namespace { + +class MockSMConnection : public SMConnection { + public: + MockSMConnection(EpollServer* epoll_server, + SSLState* ssl_state, + MemoryCache* memory_cache, + FlipAcceptor* acceptor, + std::string log_prefix) + : SMConnection(epoll_server, + ssl_state, + memory_cache, + acceptor, + log_prefix) {} + + MOCK_METHOD0(Cleanup, void()); + MOCK_METHOD8(InitSMConnection, void(SMConnectionPoolInterface*, + SMInterface*, + EpollServer*, + int, + std::string, + std::string, + std::string, + bool)); +}; + +class FlipHttpSMTest : public ::testing::Test { + public: + explicit FlipHttpSMTest(FlipHandlerType type = FLIP_HANDLER_PROXY) { + SSLState* ssl_state = NULL; + mock_another_interface_.reset(new MockSMInterface); + memory_cache_.reset(new MemoryCache); + acceptor_.reset(new FlipAcceptor(type, + "127.0.0.1", + "8941", + "ssl_cert_filename", + "ssl_key_filename", + "127.0.0.1", + "8942", + "127.0.0.1", + "8943", + 1, + 0, + true, + 1, + false, + true, + NULL)); + epoll_server_.reset(new EpollServer); + connection_.reset(new MockSMConnection(epoll_server_.get(), + ssl_state, + memory_cache_.get(), + acceptor_.get(), + "log_prefix")); + + interface_.reset(new HttpSM(connection_.get(), + mock_another_interface_.get(), + memory_cache_.get(), + acceptor_.get())); + } + + virtual void TearDown() OVERRIDE { + if (acceptor_->listen_fd_ >= 0) { + epoll_server_->UnregisterFD(acceptor_->listen_fd_); + close(acceptor_->listen_fd_); + acceptor_->listen_fd_ = -1; + } + STLDeleteElements(connection_->output_list()); + } + + bool HasStream(uint32 stream_id) { + return interface_->output_ordering().ExistsInPriorityMaps(stream_id); + } + + protected: + scoped_ptr<MockSMInterface> mock_another_interface_; + scoped_ptr<MemoryCache> memory_cache_; + scoped_ptr<FlipAcceptor> acceptor_; + scoped_ptr<EpollServer> epoll_server_; + scoped_ptr<MockSMConnection> connection_; + scoped_ptr<HttpSM> interface_; +}; + +class FlipHttpSMProxyTest : public FlipHttpSMTest { + public: + FlipHttpSMProxyTest() : FlipHttpSMTest(FLIP_HANDLER_PROXY) {} + virtual ~FlipHttpSMProxyTest() {} +}; + +class FlipHttpSMHttpTest : public FlipHttpSMTest { + public: + FlipHttpSMHttpTest() : FlipHttpSMTest(FLIP_HANDLER_HTTP_SERVER) {} + virtual ~FlipHttpSMHttpTest() {} +}; + +class FlipHttpSMSpdyTest : public FlipHttpSMTest { + public: + FlipHttpSMSpdyTest() : FlipHttpSMTest(FLIP_HANDLER_SPDY_SERVER) {} + virtual ~FlipHttpSMSpdyTest() {} +}; + +TEST_F(FlipHttpSMTest, Construct) { + ASSERT_FALSE(interface_->spdy_framer()->is_request()); +} + +TEST_F(FlipHttpSMTest, AddToOutputOrder) { + uint32 stream_id = 13; + MemCacheIter mci; + mci.stream_id = stream_id; + + { + BalsaHeaders headers; + std::string filename = "foobar"; + memory_cache_->InsertFile(&headers, filename, ""); + mci.file_data = memory_cache_->GetFileData(filename); + } + + interface_->AddToOutputOrder(mci); + ASSERT_TRUE(HasStream(stream_id)); +} + +TEST_F(FlipHttpSMTest, InitSMInterface) { + scoped_ptr<MockSMInterface> mock(new MockSMInterface); + { + InSequence s; + EXPECT_CALL(*mock_another_interface_, SendEOF(_)); + EXPECT_CALL(*mock_another_interface_, ResetForNewInterface(_)); + EXPECT_CALL(*mock, SendEOF(_)); + EXPECT_CALL(*mock, ResetForNewInterface(_)); + } + + interface_->ResetForNewConnection(); + interface_->InitSMInterface(mock.get(), 0); + interface_->ResetForNewConnection(); +} + +TEST_F(FlipHttpSMTest, InitSMConnection) { + EXPECT_CALL(*connection_, InitSMConnection(_, _, _, _, _, _, _, _)); + + interface_->InitSMConnection(NULL, NULL, NULL, 0, "", "", "", false); +} + +TEST_F(FlipHttpSMTest, ProcessReadInput) { + std::string data = "HTTP/1.1 200 OK\r\n" + "Content-Length: 14\r\n\r\n" + "hello, world\r\n"; + testing::MockFunction<void(int)> checkpoint; + { + InSequence s; + EXPECT_CALL(*mock_another_interface_, SendSynReply(_, _)); + EXPECT_CALL(checkpoint, Call(0)); + EXPECT_CALL(*mock_another_interface_, SendDataFrame(_, _, _, _, _)); + EXPECT_CALL(*mock_another_interface_, SendEOF(_)); + } + + ASSERT_EQ(BalsaFrameEnums::READING_HEADER_AND_FIRSTLINE, + interface_->spdy_framer()->ParseState()); + + size_t read = interface_->ProcessReadInput(data.data(), data.size()); + ASSERT_EQ(39u, read); + checkpoint.Call(0); + read += interface_->ProcessReadInput(&data.data()[read], data.size() - read); + ASSERT_EQ(data.size(), read); + ASSERT_EQ(BalsaFrameEnums::MESSAGE_FULLY_READ, + interface_->spdy_framer()->ParseState()); + ASSERT_TRUE(interface_->MessageFullyRead()); +} + +TEST_F(FlipHttpSMTest, ProcessWriteInput) { + std::string data = "hello, world"; + interface_->ProcessWriteInput(data.data(), data.size()); + + ASSERT_EQ(1u, connection_->output_list()->size()); + std::list<DataFrame*>::const_iterator i = connection_->output_list()->begin(); + DataFrame* df = *i++; + ASSERT_EQ(data, StringPiece(df->data, df->size)); + ASSERT_EQ(connection_->output_list()->end(), i); +} + +TEST_F(FlipHttpSMTest, Reset) { + std::string data = "HTTP/1.1 200 OK\r\n\r\n"; + testing::MockFunction<void(int)> checkpoint; + { + InSequence s; + EXPECT_CALL(*mock_another_interface_, SendSynReply(_, _)); + EXPECT_CALL(checkpoint, Call(0)); + } + + ASSERT_EQ(BalsaFrameEnums::READING_HEADER_AND_FIRSTLINE, + interface_->spdy_framer()->ParseState()); + + interface_->ProcessReadInput(data.data(), data.size()); + checkpoint.Call(0); + ASSERT_FALSE(interface_->MessageFullyRead()); + ASSERT_EQ(BalsaFrameEnums::READING_UNTIL_CLOSE, + interface_->spdy_framer()->ParseState()); + + interface_->Reset(); + ASSERT_EQ(BalsaFrameEnums::READING_HEADER_AND_FIRSTLINE, + interface_->spdy_framer()->ParseState()); +} + +TEST_F(FlipHttpSMTest, ResetForNewConnection) { + std::string data = "HTTP/1.1 200 OK\r\n\r\n"; + testing::MockFunction<void(int)> checkpoint; + { + InSequence s; + EXPECT_CALL(*mock_another_interface_, SendSynReply(_, _)); + EXPECT_CALL(checkpoint, Call(0)); + EXPECT_CALL(*mock_another_interface_, SendEOF(_)); + EXPECT_CALL(*mock_another_interface_, ResetForNewInterface(_)); + } + + ASSERT_EQ(BalsaFrameEnums::READING_HEADER_AND_FIRSTLINE, + interface_->spdy_framer()->ParseState()); + + interface_->ProcessReadInput(data.data(), data.size()); + checkpoint.Call(0); + ASSERT_FALSE(interface_->MessageFullyRead()); + ASSERT_EQ(BalsaFrameEnums::READING_UNTIL_CLOSE, + interface_->spdy_framer()->ParseState()); + + interface_->ResetForNewConnection(); + ASSERT_EQ(BalsaFrameEnums::READING_HEADER_AND_FIRSTLINE, + interface_->spdy_framer()->ParseState()); +} + +TEST_F(FlipHttpSMTest, NewStream) { + uint32 stream_id = 4; + { + BalsaHeaders headers; + std::string filename = "foobar"; + memory_cache_->InsertFile(&headers, filename, ""); + } + + interface_->NewStream(stream_id, 1, "foobar"); + ASSERT_TRUE(HasStream(stream_id)); +} + +TEST_F(FlipHttpSMTest, NewStreamError) { + std::string syn_reply = "HTTP/1.1 404 Not Found\r\n" + "transfer-encoding: chunked\r\n\r\n"; + std::string body = "e\r\npage not found\r\n"; + uint32 stream_id = 4; + + ASSERT_FALSE(HasStream(stream_id)); + interface_->NewStream(stream_id, 1, "foobar"); + + ASSERT_EQ(3u, connection_->output_list()->size()); + std::list<DataFrame*>::const_iterator i = connection_->output_list()->begin(); + DataFrame* df = *i++; + ASSERT_EQ(syn_reply, StringPiece(df->data, df->size)); + df = *i++; + ASSERT_EQ(body, StringPiece(df->data, df->size)); + df = *i++; + ASSERT_EQ("0\r\n\r\n", StringPiece(df->data, df->size)); + ASSERT_FALSE(HasStream(stream_id)); +} + +TEST_F(FlipHttpSMTest, SendErrorNotFound) { + std::string syn_reply = "HTTP/1.1 404 Not Found\r\n" + "transfer-encoding: chunked\r\n\r\n"; + std::string body = "e\r\npage not found\r\n"; + uint32 stream_id = 13; + MemCacheIter mci; + mci.stream_id = stream_id; + + { + BalsaHeaders headers; + std::string filename = "foobar"; + memory_cache_->InsertFile(&headers, filename, ""); + mci.file_data = memory_cache_->GetFileData(filename); + } + + interface_->AddToOutputOrder(mci); + ASSERT_TRUE(HasStream(stream_id)); + interface_->SendErrorNotFound(stream_id); + + ASSERT_EQ(3u, connection_->output_list()->size()); + std::list<DataFrame*>::const_iterator i = connection_->output_list()->begin(); + DataFrame* df = *i++; + ASSERT_EQ(syn_reply, StringPiece(df->data, df->size)); + df = *i++; + ASSERT_EQ(body, StringPiece(df->data, df->size)); + df = *i++; + ASSERT_EQ("0\r\n\r\n", StringPiece(df->data, df->size)); + ASSERT_FALSE(HasStream(stream_id)); +} + +TEST_F(FlipHttpSMTest, SendSynStream) { + std::string expected = "GET / HTTP/1.0\r\n" + "key1: value1\r\n\r\n"; + BalsaHeaders headers; + headers.SetResponseFirstlineFromStringPieces("GET", "/path", "HTTP/1.0"); + headers.AppendHeader("key1", "value1"); + interface_->SendSynStream(18, headers); + + // TODO(yhirano): Is this behavior correct? + ASSERT_EQ(0u, connection_->output_list()->size()); +} + +TEST_F(FlipHttpSMTest, SendSynReply) { + std::string expected = "HTTP/1.1 200 OK\r\n" + "key1: value1\r\n\r\n"; + BalsaHeaders headers; + headers.SetResponseFirstlineFromStringPieces("HTTP/1.1", "200", "OK"); + headers.AppendHeader("key1", "value1"); + interface_->SendSynReply(18, headers); + + ASSERT_EQ(1u, connection_->output_list()->size()); + DataFrame* df = connection_->output_list()->front(); + ASSERT_EQ(expected, StringPiece(df->data, df->size)); +} + +TEST_F(FlipHttpSMTest, SendDataFrame) { + std::string data = "foo bar baz"; + interface_->SendDataFrame(12, data.data(), data.size(), 0, false); + + ASSERT_EQ(1u, connection_->output_list()->size()); + DataFrame* df = connection_->output_list()->front(); + ASSERT_EQ("b\r\nfoo bar baz\r\n", StringPiece(df->data, df->size)); +} + +TEST_F(FlipHttpSMProxyTest, ProcessBodyData) { + BalsaVisitorInterface* visitor = interface_.get(); + std::string data = "hello, world"; + { + InSequence s; + EXPECT_CALL(*mock_another_interface_, + SendDataFrame(0, data.data(), data.size(), 0, false)); + } + visitor->ProcessBodyData(data.data(), data.size()); +} + +// -- +// FlipHttpSMProxyTest + +TEST_F(FlipHttpSMProxyTest, ProcessHeaders) { + BalsaVisitorInterface* visitor = interface_.get(); + { + InSequence s; + EXPECT_CALL(*mock_another_interface_, SendSynReply(0, _)); + } + BalsaHeaders headers; + visitor->ProcessHeaders(headers); +} + +TEST_F(FlipHttpSMProxyTest, MessageDone) { + BalsaVisitorInterface* visitor = interface_.get(); + { + InSequence s; + EXPECT_CALL(*mock_another_interface_, SendEOF(0)); + } + visitor->MessageDone(); +} + +TEST_F(FlipHttpSMProxyTest, Cleanup) { + EXPECT_CALL(*connection_, Cleanup()).Times(0); + interface_->Cleanup(); +} + +TEST_F(FlipHttpSMProxyTest, SendEOF) { + { + InSequence s; + EXPECT_CALL(*mock_another_interface_, ResetForNewInterface(_)); + } + interface_->SendEOF(32); + ASSERT_EQ(1u, connection_->output_list()->size()); + DataFrame* df = connection_->output_list()->front(); + ASSERT_EQ("0\r\n\r\n", StringPiece(df->data, df->size)); +} + +// -- +// FlipHttpSMHttpTest + +TEST_F(FlipHttpSMHttpTest, ProcessHeaders) { + BalsaVisitorInterface* visitor = interface_.get(); + { + BalsaHeaders headers; + std::string filename = "GET_/path/file"; + memory_cache_->InsertFile(&headers, filename, ""); + } + + BalsaHeaders headers; + headers.AppendHeader("Host", "example.com"); + headers.SetRequestFirstlineFromStringPieces("GET", + "/path/file", + "HTTP/1.0"); + uint32 stream_id = 133; + interface_->SetStreamID(stream_id); + ASSERT_FALSE(HasStream(stream_id)); + visitor->ProcessHeaders(headers); + ASSERT_TRUE(HasStream(stream_id)); +} + +TEST_F(FlipHttpSMHttpTest, MessageDone) { + BalsaVisitorInterface* visitor = interface_.get(); + { + InSequence s; + EXPECT_CALL(*mock_another_interface_, SendEOF(0)).Times(0); + } + visitor->MessageDone(); +} + +TEST_F(FlipHttpSMHttpTest, Cleanup) { + EXPECT_CALL(*connection_, Cleanup()).Times(0); + interface_->Cleanup(); +} + +TEST_F(FlipHttpSMHttpTest, SendEOF) { + { + InSequence s; + EXPECT_CALL(*mock_another_interface_, ResetForNewInterface(_)).Times(0); + } + interface_->SendEOF(32); + ASSERT_EQ(1u, connection_->output_list()->size()); + DataFrame* df = connection_->output_list()->front(); + ASSERT_EQ("0\r\n\r\n", StringPiece(df->data, df->size)); +} + +// -- +// FlipHttpSMSpdyTest + +TEST_F(FlipHttpSMSpdyTest, ProcessHeaders) { + BalsaVisitorInterface* visitor = interface_.get(); + { + InSequence s; + EXPECT_CALL(*mock_another_interface_, SendSynReply(0, _)); + } + BalsaHeaders headers; + visitor->ProcessHeaders(headers); +} + +TEST_F(FlipHttpSMSpdyTest, MessageDone) { + BalsaVisitorInterface* visitor = interface_.get(); + { + InSequence s; + EXPECT_CALL(*mock_another_interface_, SendEOF(0)).Times(0); + } + visitor->MessageDone(); +} + +TEST_F(FlipHttpSMSpdyTest, Cleanup) { + EXPECT_CALL(*connection_, Cleanup()).Times(0); + interface_->Cleanup(); +} + +TEST_F(FlipHttpSMSpdyTest, SendEOF) { + { + InSequence s; + EXPECT_CALL(*mock_another_interface_, ResetForNewInterface(_)).Times(0); + } + interface_->SendEOF(32); + ASSERT_EQ(1u, connection_->output_list()->size()); + DataFrame* df = connection_->output_list()->front(); + ASSERT_EQ("0\r\n\r\n", StringPiece(df->data, df->size)); +} + +} // namespace + +} // namespace net diff --git a/chromium/net/tools/flip_server/mem_cache.cc b/chromium/net/tools/flip_server/mem_cache.cc index 924920825c3..d1e0e58be41 100644 --- a/chromium/net/tools/flip_server/mem_cache.cc +++ b/chromium/net/tools/flip_server/mem_cache.cc @@ -202,17 +202,9 @@ void MemoryCache::ReadAndStoreFileContents(const char* filename) { if (slash_pos == std::string::npos) { slash_pos = filename_stripped.size(); } - FileData* data = - new FileData(&visitor.headers, - filename_stripped.substr(0, slash_pos), - visitor.body); - Files::iterator it = files_.find(filename_stripped); - if (it != files_.end()) { - delete it->second; - it->second = data; - } else { - files_.insert(std::make_pair(filename_stripped, data)); - } + InsertFile(&visitor.headers, + filename_stripped.substr(0, slash_pos), + visitor.body); } FileData* MemoryCache::GetFileData(const std::string& filename) { @@ -239,6 +231,22 @@ bool MemoryCache::AssignFileData(const std::string& filename, return true; } +void MemoryCache::InsertFile(const BalsaHeaders* headers, + const std::string& filename, + const std::string& body) { + InsertFile(new FileData(headers, filename, body)); +} + +void MemoryCache::InsertFile(FileData* file_data) { + Files::iterator it = files_.find(file_data->filename()); + if (it != files_.end()) { + delete it->second; + it->second = file_data; + } else { + files_.insert(std::make_pair(file_data->filename(), file_data)); + } +} + void MemoryCache::ClearFiles() { for (Files::const_iterator i = files_.begin(); i != files_.end(); ++i) { delete i->second; diff --git a/chromium/net/tools/flip_server/mem_cache.h b/chromium/net/tools/flip_server/mem_cache.h index 806ae53689b..300c84a3fe0 100644 --- a/chromium/net/tools/flip_server/mem_cache.h +++ b/chromium/net/tools/flip_server/mem_cache.h @@ -135,7 +135,13 @@ class MemoryCache { bool AssignFileData(const std::string& filename, MemCacheIter* mci); + // For unittests + void InsertFile(const BalsaHeaders* headers, + const std::string& filename, + const std::string& body); + private: + void InsertFile(FileData* file_data); void ClearFiles(); Files files_; diff --git a/chromium/net/tools/flip_server/output_ordering.cc b/chromium/net/tools/flip_server/output_ordering.cc index 22fd08ac1c1..6a42869bc8e 100644 --- a/chromium/net/tools/flip_server/output_ordering.cc +++ b/chromium/net/tools/flip_server/output_ordering.cc @@ -4,6 +4,8 @@ #include "net/tools/flip_server/output_ordering.h" +#include <utility> + #include "net/tools/flip_server/flip_config.h" #include "net/tools/flip_server/sm_connection.h" @@ -25,7 +27,9 @@ OutputOrdering::OutputOrdering(SMConnectionInterface* connection) epoll_server_ = connection->epoll_server(); } -OutputOrdering::~OutputOrdering() {} +OutputOrdering::~OutputOrdering() { + Reset(); +} void OutputOrdering::Reset() { while (!stream_ids_.empty()) { @@ -40,8 +44,8 @@ void OutputOrdering::Reset() { first_data_senders_.clear(); } -bool OutputOrdering::ExistsInPriorityMaps(uint32 stream_id) { - StreamIdToPriorityMap::iterator sitpmi = stream_ids_.find(stream_id); +bool OutputOrdering::ExistsInPriorityMaps(uint32 stream_id) const { + StreamIdToPriorityMap::const_iterator sitpmi = stream_ids_.find(stream_id); return sitpmi != stream_ids_.end(); } @@ -78,6 +82,7 @@ void OutputOrdering::BeginOutputtingAlarm::OnRegistration( void OutputOrdering::BeginOutputtingAlarm::OnUnregistration() { pmp_->alarm_enabled = false; + delete this; } void OutputOrdering::BeginOutputtingAlarm::OnShutdown(EpollServer* eps) { diff --git a/chromium/net/tools/flip_server/output_ordering.h b/chromium/net/tools/flip_server/output_ordering.h index 0558e3e4e8d..922d03fc864 100644 --- a/chromium/net/tools/flip_server/output_ordering.h +++ b/chromium/net/tools/flip_server/output_ordering.h @@ -45,7 +45,7 @@ class OutputOrdering { explicit OutputOrdering(SMConnectionInterface* connection); ~OutputOrdering(); void Reset(); - bool ExistsInPriorityMaps(uint32 stream_id); + bool ExistsInPriorityMaps(uint32 stream_id) const; struct BeginOutputtingAlarm : public EpollAlarmCallbackInterface { public: diff --git a/chromium/net/tools/flip_server/sm_connection.cc b/chromium/net/tools/flip_server/sm_connection.cc index 71e81a08ae9..375bc2ec79d 100644 --- a/chromium/net/tools/flip_server/sm_connection.cc +++ b/chromium/net/tools/flip_server/sm_connection.cc @@ -9,6 +9,7 @@ #include <sys/socket.h> #include <unistd.h> +#include <algorithm> #include <list> #include <string> @@ -357,7 +358,6 @@ bool SMConnection::SetupProtocolInterfaces() { if (!sm_http_interface_) sm_http_interface_ = new HttpSM(this, NULL, - epoll_server_, memory_cache_, acceptor_); sm_interface_ = sm_http_interface_; diff --git a/chromium/net/tools/flip_server/sm_connection.h b/chromium/net/tools/flip_server/sm_connection.h index e7f1e99bfa8..3e217729a52 100644 --- a/chromium/net/tools/flip_server/sm_connection.h +++ b/chromium/net/tools/flip_server/sm_connection.h @@ -65,14 +65,14 @@ class SMConnection : public SMConnectionInterface, bool initialized() const { return initialized_; } std::string client_ip() const { return client_ip_; } - void InitSMConnection(SMConnectionPoolInterface* connection_pool, - SMInterface* sm_interface, - EpollServer* epoll_server, - int fd, - std::string server_ip, - std::string server_port, - std::string remote_ip, - bool use_ssl); + virtual void InitSMConnection(SMConnectionPoolInterface* connection_pool, + SMInterface* sm_interface, + EpollServer* epoll_server, + int fd, + std::string server_ip, + std::string server_port, + std::string remote_ip, + bool use_ssl); void CorkSocket(); void UncorkSocket(); @@ -119,12 +119,12 @@ class SMConnection : public SMConnectionInterface, return os; } - private: SMConnection(EpollServer* epoll_server, SSLState* ssl_state, MemoryCache* memory_cache, FlipAcceptor* acceptor, std::string log_prefix); + private: int fd_; int events_; diff --git a/chromium/net/tools/flip_server/spdy_interface.cc b/chromium/net/tools/flip_server/spdy_interface.cc index b4e36bed229..359eab810a7 100644 --- a/chromium/net/tools/flip_server/spdy_interface.cc +++ b/chromium/net/tools/flip_server/spdy_interface.cc @@ -22,7 +22,7 @@ std::string SpdySM::forward_ip_header_; class SpdyFrameDataFrame : public DataFrame { public: - SpdyFrameDataFrame(SpdyFrame* spdy_frame) + explicit SpdyFrameDataFrame(SpdyFrame* spdy_frame) : frame(spdy_frame) { data = spdy_frame->data(); size = spdy_frame->size(); @@ -86,14 +86,14 @@ SMInterface* SpdySM::NewConnectionInterface() { VLOG(2) << ACCEPTOR_CLIENT_IDENT << "SpdySM: Creating new HTTP interface"; SMInterface *sm_http_interface = new HttpSM(server_connection, this, - epoll_server_, memory_cache_, acceptor_); return sm_http_interface; } SMInterface* SpdySM::FindOrMakeNewSMConnectionInterface( - std::string server_ip, std::string server_port) { + const std::string& server_ip, + const std::string& server_port) { SMInterface *sm_http_interface; int32 server_idx; if (unused_server_interface_list.empty()) { @@ -154,7 +154,7 @@ int SpdySM::SpdyHandleNewStream( // UrlUtilities::GetUrlPath will fail and always return a / breaking // the request. GetUrlPath assumes the absolute URL is being passed in. std::string uri; - if (url->second.compare(0,4,"http") == 0) + if (url->second.compare(0, 4, "http") == 0) uri = UrlUtilities::GetUrlPath(url->second); else uri = std::string(url->second); @@ -323,6 +323,9 @@ void SpdySM::NewStream(uint32 stream_id, MemCacheIter mci; mci.stream_id = stream_id; mci.priority = priority; + // TODO(yhirano): The program will crash when + // acceptor_->flip_handler_type_ != FLIP_HANDLER_SPDY_SERVER. + // It should be fixed or an assertion should be placed. if (acceptor_->flip_handler_type_ == FLIP_HANDLER_SPDY_SERVER) { if (!memory_cache_->AssignFileData(filename, &mci)) { // error creating new stream. @@ -348,10 +351,6 @@ void SpdySM::SendErrorNotFound(uint32 stream_id) { SendErrorNotFoundImpl(stream_id); } -void SpdySM::SendOKResponse(uint32 stream_id, std::string* output) { - SendOKResponseImpl(stream_id, output); -} - size_t SpdySM::SendSynStream(uint32 stream_id, const BalsaHeaders& headers) { return SendSynStreamImpl(stream_id, headers); } @@ -381,15 +380,6 @@ void SpdySM::SendErrorNotFoundImpl(uint32 stream_id) { client_output_ordering_.RemoveStreamId(stream_id); } -void SpdySM::SendOKResponseImpl(uint32 stream_id, std::string* output) { - BalsaHeaders my_headers; - my_headers.SetFirstlineFromStringPieces("HTTP/1.1", "200", "OK"); - SendSynReplyImpl(stream_id, my_headers); - SendDataFrame( - stream_id, output->c_str(), output->size(), DATA_FLAG_FIN, false); - client_output_ordering_.RemoveStreamId(stream_id); -} - void SpdySM::KillStream(uint32 stream_id) { client_output_ordering_.RemoveStreamId(stream_id); } @@ -426,15 +416,13 @@ size_t SpdySM::SendSynStreamImpl(uint32 stream_id, const BalsaHeaders& headers) { SpdyHeaderBlock block; block["method"] = headers.request_method().as_string(); - if (!headers.HasHeader("status")) - block["status"] = headers.response_code().as_string(); if (!headers.HasHeader("version")) - block["version"] =headers.response_version().as_string(); + block["version"] =headers.request_version().as_string(); if (headers.HasHeader("X-Original-Url")) { std::string original_url = headers.GetHeader("X-Original-Url").as_string(); - block["path"] = UrlUtilities::GetUrlPath(original_url); + block["url"] = UrlUtilities::GetUrlPath(original_url); } else { - block["path"] = headers.request_uri().as_string(); + block["url"] = headers.request_uri().as_string(); } CopyHeaders(block, headers); diff --git a/chromium/net/tools/flip_server/spdy_interface.h b/chromium/net/tools/flip_server/spdy_interface.h index a83dab7f406..3ce1f3bd0aa 100644 --- a/chromium/net/tools/flip_server/spdy_interface.h +++ b/chromium/net/tools/flip_server/spdy_interface.h @@ -48,8 +48,10 @@ class SpdySM : public BufferedSpdyFramerVisitorInterface, private: virtual void set_is_request() OVERRIDE {} SMInterface* NewConnectionInterface(); - SMInterface* FindOrMakeNewSMConnectionInterface(std::string server_ip, - std::string server_port); + // virtual for tests + virtual SMInterface* FindOrMakeNewSMConnectionInterface( + const std::string& server_ip, + const std::string& server_port); int SpdyHandleNewStream(SpdyStreamId stream_id, SpdyPriority priority, const SpdyHeaderBlock& headers, @@ -143,26 +145,26 @@ class SpdySM : public BufferedSpdyFramerVisitorInterface, void AddToOutputOrder(const MemCacheIter& mci); virtual void SendEOF(uint32 stream_id) OVERRIDE; virtual void SendErrorNotFound(uint32 stream_id) OVERRIDE; - void SendOKResponse(uint32 stream_id, std::string* output); virtual size_t SendSynStream(uint32 stream_id, const BalsaHeaders& headers) OVERRIDE; virtual size_t SendSynReply(uint32 stream_id, const BalsaHeaders& headers) OVERRIDE; virtual void SendDataFrame(uint32 stream_id, const char* data, int64 len, uint32 flags, bool compress) OVERRIDE; - BufferedSpdyFramer* spdy_framer() { - return buffered_spdy_framer_; + BufferedSpdyFramer* spdy_framer() { return buffered_spdy_framer_; } + + const OutputOrdering& output_ordering() const { + return client_output_ordering_; } static std::string forward_ip_header() { return forward_ip_header_; } - static void set_forward_ip_header(std::string value) { + static void set_forward_ip_header(const std::string& value) { forward_ip_header_ = value; } private: void SendEOFImpl(uint32 stream_id); void SendErrorNotFoundImpl(uint32 stream_id); - void SendOKResponseImpl(uint32 stream_id, std::string* output); void KillStream(uint32 stream_id); void CopyHeaders(SpdyHeaderBlock& dest, const BalsaHeaders& headers); size_t SendSynStreamImpl(uint32 stream_id, const BalsaHeaders& headers); @@ -171,6 +173,7 @@ class SpdySM : public BufferedSpdyFramerVisitorInterface, SpdyDataFlags flags, bool compress); void EnqueueDataFrame(DataFrame* df); virtual void GetOutput() OVERRIDE; + private: BufferedSpdyFramer* buffered_spdy_framer_; bool valid_spdy_session_; // True if we have seen valid data on this session. diff --git a/chromium/net/tools/flip_server/spdy_interface_test.cc b/chromium/net/tools/flip_server/spdy_interface_test.cc new file mode 100644 index 00000000000..7a1c6e97e3b --- /dev/null +++ b/chromium/net/tools/flip_server/spdy_interface_test.cc @@ -0,0 +1,634 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/tools/flip_server/spdy_interface.h" + +#include <list> + +#include "base/memory/scoped_ptr.h" +#include "base/strings/string_piece.h" +#include "net/spdy/buffered_spdy_framer.h" +#include "net/tools/flip_server/balsa_enums.h" +#include "net/tools/flip_server/balsa_headers.h" +#include "net/tools/flip_server/flip_config.h" +#include "net/tools/flip_server/flip_test_utils.h" +#include "net/tools/flip_server/mem_cache.h" +#include "testing/gmock/include/gmock/gmock.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace net { + +using ::base::StringPiece; +using ::testing::_; +using ::testing::InSequence; +using ::testing::InvokeWithoutArgs; +using ::testing::Return; +using ::testing::SaveArg; + +namespace { + +struct StringSaver { + public: + StringSaver() : data(NULL), size(0) {} + void Save() { + string = std::string(data, size); + } + + const char* data; + size_t size; + std::string string; +}; + +class SpdyFramerVisitor : public BufferedSpdyFramerVisitorInterface { + public: + virtual ~SpdyFramerVisitor() {} + MOCK_METHOD1(OnError, void(SpdyFramer::SpdyError)); + MOCK_METHOD2(OnStreamError, void(SpdyStreamId, const std::string&)); + MOCK_METHOD7(OnSynStream, void(SpdyStreamId, + SpdyStreamId, + SpdyPriority, + uint8, + bool, + bool, + const SpdyHeaderBlock&)); + MOCK_METHOD3(OnSynStream, void(SpdyStreamId, bool, const SpdyHeaderBlock&)); + MOCK_METHOD3(OnSynReply, void(SpdyStreamId, bool, const SpdyHeaderBlock&)); + MOCK_METHOD3(OnHeaders, void(SpdyStreamId, bool, const SpdyHeaderBlock&)); + MOCK_METHOD4(OnStreamFrameData, void(SpdyStreamId, + const char*, + size_t, + bool)); + MOCK_METHOD1(OnStreamFrameData, void(bool)); + MOCK_METHOD1(OnSettings, void(bool)); + MOCK_METHOD3(OnSetting, void(SpdySettingsIds, uint8, uint32)); + MOCK_METHOD1(OnPing, void(uint32)); + MOCK_METHOD2(OnRstStream, void(SpdyStreamId, SpdyRstStreamStatus)); + MOCK_METHOD2(OnGoAway, void(SpdyStreamId, SpdyGoAwayStatus)); + MOCK_METHOD2(OnWindowUpdate, void(SpdyStreamId, uint32)); + MOCK_METHOD2(OnPushPromise, void(SpdyStreamId, SpdyStreamId)); +}; + +class FakeSMConnection : public SMConnection { + public: + FakeSMConnection(EpollServer* epoll_server, + SSLState* ssl_state, + MemoryCache* memory_cache, + FlipAcceptor* acceptor, + std::string log_prefix) + : SMConnection(epoll_server, + ssl_state, + memory_cache, + acceptor, + log_prefix) {} + + MOCK_METHOD0(Cleanup, void()); + MOCK_METHOD8(InitSMConnection, void(SMConnectionPoolInterface*, + SMInterface*, + EpollServer*, + int, + std::string, + std::string, + std::string, + bool)); +}; + +class SpdySMWithMockSMInterfaceFactory : public SpdySM { + public: + virtual ~SpdySMWithMockSMInterfaceFactory() {} + SpdySMWithMockSMInterfaceFactory(SMConnection* connection, + SMInterface* sm_http_interface, + EpollServer* epoll_server, + MemoryCache* memory_cache, + FlipAcceptor* acceptor) + : SpdySM(connection, + sm_http_interface, + epoll_server, + memory_cache, + acceptor) {} + + MOCK_METHOD2(FindOrMakeNewSMConnectionInterface, + SMInterface*(const std::string&, const std::string&)); +}; + +class FlipSpdySMTest : public ::testing::Test { + public: + FlipSpdySMTest(SpdyMajorVersion version = SPDY2, + FlipHandlerType type = FLIP_HANDLER_PROXY) + : spdy_version_(version) { + SSLState* ssl_state = NULL; + mock_another_interface_.reset(new MockSMInterface); + memory_cache_.reset(new MemoryCache); + acceptor_.reset(new FlipAcceptor(type, + "127.0.0.1", + "8941", + "ssl_cert_filename", + "ssl_key_filename", + "127.0.0.1", + "8942", + "127.0.0.1", + "8943", + 1, + 0, + true, + 1, + false, + true, + NULL)); + epoll_server_.reset(new EpollServer); + connection_.reset(new FakeSMConnection(epoll_server_.get(), + ssl_state, + memory_cache_.get(), + acceptor_.get(), + "log_prefix")); + + interface_.reset(new SpdySMWithMockSMInterfaceFactory( + connection_.get(), + mock_another_interface_.get(), + epoll_server_.get(), + memory_cache_.get(), + acceptor_.get())); + + spdy_framer_.reset(new BufferedSpdyFramer(spdy_version_, true)); + spdy_framer_visitor_.reset(new SpdyFramerVisitor); + spdy_framer_->set_visitor(spdy_framer_visitor_.get()); + } + + virtual void TearDown() OVERRIDE { + if (acceptor_->listen_fd_ >= 0) { + epoll_server_->UnregisterFD(acceptor_->listen_fd_); + close(acceptor_->listen_fd_); + acceptor_->listen_fd_ = -1; + } + OutputList& output_list = *connection_->output_list(); + for (OutputList::const_iterator i = output_list.begin(); + i != output_list.end(); + ++i) { + delete *i; + } + output_list.clear(); + } + + bool HasStream(uint32 stream_id) { + return interface_->output_ordering().ExistsInPriorityMaps(stream_id); + } + + protected: + SpdyMajorVersion spdy_version_; + scoped_ptr<MockSMInterface> mock_another_interface_; + scoped_ptr<MemoryCache> memory_cache_; + scoped_ptr<FlipAcceptor> acceptor_; + scoped_ptr<EpollServer> epoll_server_; + scoped_ptr<FakeSMConnection> connection_; + scoped_ptr<SpdySMWithMockSMInterfaceFactory> interface_; + scoped_ptr<BufferedSpdyFramer> spdy_framer_; + scoped_ptr<SpdyFramerVisitor> spdy_framer_visitor_; +}; + +class FlipSpdy2SMTest : public FlipSpdySMTest { + public: + FlipSpdy2SMTest(): FlipSpdySMTest(SPDY2) {} + virtual ~FlipSpdy2SMTest() {} +}; + +class FlipSpdy2SMTestNonProxy : public FlipSpdySMTest { + public: + FlipSpdy2SMTestNonProxy(): FlipSpdySMTest(SPDY2, FLIP_HANDLER_SPDY_SERVER) {} + virtual ~FlipSpdy2SMTestNonProxy() {} +}; + +TEST_F(FlipSpdySMTest, InitSMConnection) { + { + InSequence s; + EXPECT_CALL(*connection_, InitSMConnection(_, _, _, _, _, _, _, _)); + } + interface_->InitSMConnection(NULL, + NULL, + epoll_server_.get(), + -1, + "", + "", + "", + false); +} + +TEST_F(FlipSpdySMTest, OnSynStream) { + BufferedSpdyFramerVisitorInterface* visitor = interface_.get(); + scoped_ptr<MockSMInterface> mock_interface(new MockSMInterface); + uint32 stream_id = 92; + uint32 associated_id = 43; + std::string expected = "GET /path HTTP/1.0\r\n" + "method: GET\r\n" + "scheme: http\r\n" + "url: http://www.example.com/path\r\n" + "version: HTTP/1.0\r\n\r\n"; + SpdyHeaderBlock block; + block["method"] = "GET"; + block["url"] = "http://www.example.com/path"; + block["scheme"] = "http"; + block["version"] = "HTTP/1.0"; + StringSaver saver; + { + InSequence s; + EXPECT_CALL(*interface_, + FindOrMakeNewSMConnectionInterface(_, _)) + .WillOnce(Return(mock_interface.get())); + EXPECT_CALL(*mock_interface, SetStreamID(stream_id)); + EXPECT_CALL(*mock_interface, ProcessWriteInput(_, _)) + .WillOnce(DoAll(SaveArg<0>(&saver.data), + SaveArg<1>(&saver.size), + InvokeWithoutArgs(&saver, &StringSaver::Save), + Return(0))); + } + visitor->OnSynStream(stream_id, associated_id, 0, 0, false, false, block); + ASSERT_EQ(expected, saver.string); +} + +TEST_F(FlipSpdySMTest, OnStreamFrameData) { + BufferedSpdyFramerVisitorInterface* visitor = interface_.get(); + scoped_ptr<MockSMInterface> mock_interface(new MockSMInterface); + uint32 stream_id = 92; + uint32 associated_id = 43; + SpdyHeaderBlock block; + testing::MockFunction<void(int)> checkpoint; + + scoped_ptr<SpdyFrame> frame(spdy_framer_->CreatePingFrame(12)); + block["method"] = "GET"; + block["url"] = "http://www.example.com/path"; + block["scheme"] = "http"; + block["version"] = "HTTP/1.0"; + { + InSequence s; + EXPECT_CALL(*interface_, + FindOrMakeNewSMConnectionInterface(_, _)) + .WillOnce(Return(mock_interface.get())); + EXPECT_CALL(*mock_interface, SetStreamID(stream_id)); + EXPECT_CALL(*mock_interface, ProcessWriteInput(_, _)).Times(1); + EXPECT_CALL(checkpoint, Call(0)); + EXPECT_CALL(*mock_interface, + ProcessWriteInput(frame->data(), frame->size())).Times(1); + } + + visitor->OnSynStream(stream_id, associated_id, 0, 0, false, false, block); + checkpoint.Call(0); + visitor->OnStreamFrameData(stream_id, frame->data(), frame->size(), true); +} + +TEST_F(FlipSpdySMTest, OnRstStream) { + BufferedSpdyFramerVisitorInterface* visitor = interface_.get(); + uint32 stream_id = 82; + MemCacheIter mci; + mci.stream_id = stream_id; + + { + BalsaHeaders headers; + std::string filename = "foobar"; + memory_cache_->InsertFile(&headers, filename, ""); + mci.file_data = memory_cache_->GetFileData(filename); + } + + interface_->AddToOutputOrder(mci); + ASSERT_TRUE(HasStream(stream_id)); + visitor->OnRstStream(stream_id, RST_STREAM_INVALID); + ASSERT_FALSE(HasStream(stream_id)); +} + +TEST_F(FlipSpdySMTest, ProcessReadInput) { + ASSERT_EQ(SpdyFramer::SPDY_RESET, interface_->spdy_framer()->state()); + interface_->ProcessReadInput("", 1); + ASSERT_EQ(SpdyFramer::SPDY_READING_COMMON_HEADER, + interface_->spdy_framer()->state()); +} + +TEST_F(FlipSpdySMTest, ResetForNewConnection) { + uint32 stream_id = 13; + MemCacheIter mci; + mci.stream_id = stream_id; + // incomplete input + const char input[] = {'\0', '\0', '\0'}; + + { + BalsaHeaders headers; + std::string filename = "foobar"; + memory_cache_->InsertFile(&headers, filename, ""); + mci.file_data = memory_cache_->GetFileData(filename); + } + + interface_->AddToOutputOrder(mci); + ASSERT_TRUE(HasStream(stream_id)); + interface_->ProcessReadInput(input, sizeof(input)); + ASSERT_NE(SpdyFramer::SPDY_RESET, interface_->spdy_framer()->state()); + + interface_->ResetForNewConnection(); + ASSERT_FALSE(HasStream(stream_id)); + ASSERT_EQ(SpdyFramer::SPDY_RESET, interface_->spdy_framer()->state()); +} + +TEST_F(FlipSpdySMTest, PostAcceptHook) { + interface_->PostAcceptHook(); + + ASSERT_EQ(1u, connection_->output_list()->size()); + std::list<DataFrame*>::const_iterator i = connection_->output_list()->begin(); + DataFrame* df = *i++; + + { + InSequence s; + EXPECT_CALL(*spdy_framer_visitor_, OnSettings(false)); + EXPECT_CALL(*spdy_framer_visitor_, OnSetting( + SETTINGS_MAX_CONCURRENT_STREAMS, 0u, 100u)); + } + spdy_framer_->ProcessInput(df->data, df->size); +} + +TEST_F(FlipSpdySMTest, NewStream) { + // TODO(yhirano): SpdySM::NewStream leads to crash when + // acceptor_->flip_handler_type_ != FLIP_HANDLER_SPDY_SERVER. + // It should be fixed though I don't know the solution now. +} + +TEST_F(FlipSpdySMTest, AddToOutputOrder) { + uint32 stream_id = 13; + MemCacheIter mci; + mci.stream_id = stream_id; + + { + BalsaHeaders headers; + std::string filename = "foobar"; + memory_cache_->InsertFile(&headers, filename, ""); + mci.file_data = memory_cache_->GetFileData(filename); + } + + interface_->AddToOutputOrder(mci); + ASSERT_TRUE(HasStream(stream_id)); +} + +TEST_F(FlipSpdySMTest, SendErrorNotFound) { + uint32 stream_id = 82; + SpdyHeaderBlock actual_header_block; + const char* actual_data; + size_t actual_size; + testing::MockFunction<void(int)> checkpoint; + + interface_->SendErrorNotFound(stream_id); + + ASSERT_EQ(2u, connection_->output_list()->size()); + + { + InSequence s; + EXPECT_CALL(*spdy_framer_visitor_, + OnSynReply(stream_id, false, _)) + .WillOnce(SaveArg<2>(&actual_header_block)); + EXPECT_CALL(checkpoint, Call(0)); + EXPECT_CALL(*spdy_framer_visitor_, + OnStreamFrameData(stream_id, _, _, false)).Times(1) + .WillOnce(DoAll(SaveArg<1>(&actual_data), + SaveArg<2>(&actual_size))); + EXPECT_CALL(*spdy_framer_visitor_, + OnStreamFrameData(stream_id, NULL, 0, true)).Times(1); + } + + std::list<DataFrame*>::const_iterator i = connection_->output_list()->begin(); + DataFrame* df = *i++; + spdy_framer_->ProcessInput(df->data, df->size); + checkpoint.Call(0); + df = *i++; + spdy_framer_->ProcessInput(df->data, df->size); + + ASSERT_EQ(2, spdy_framer_->frames_received()); + ASSERT_EQ(2u, actual_header_block.size()); + ASSERT_EQ("404 Not Found", actual_header_block["status"]); + ASSERT_EQ("HTTP/1.1", actual_header_block["version"]); + ASSERT_EQ("wtf?", StringPiece(actual_data, actual_size)); +} + +TEST_F(FlipSpdySMTest, SendSynStream) { + uint32 stream_id = 82; + BalsaHeaders headers; + SpdyHeaderBlock actual_header_block; + headers.AppendHeader("key1", "value1"); + headers.SetRequestFirstlineFromStringPieces("GET", "/path", "HTTP/1.0"); + + interface_->SendSynStream(stream_id, headers); + + ASSERT_EQ(1u, connection_->output_list()->size()); + std::list<DataFrame*>::const_iterator i = connection_->output_list()->begin(); + DataFrame* df = *i++; + + { + InSequence s; + EXPECT_CALL(*spdy_framer_visitor_, + OnSynStream(stream_id, 0, _, _, false, false, _)) + .WillOnce(SaveArg<6>(&actual_header_block)); + } + + spdy_framer_->ProcessInput(df->data, df->size); + ASSERT_EQ(1, spdy_framer_->frames_received()); + ASSERT_EQ(4u, actual_header_block.size()); + ASSERT_EQ("GET", actual_header_block["method"]); + ASSERT_EQ("HTTP/1.0", actual_header_block["version"]); + ASSERT_EQ("/path", actual_header_block["url"]); + ASSERT_EQ("value1", actual_header_block["key1"]); +} + +TEST_F(FlipSpdySMTest, SendSynReply) { + uint32 stream_id = 82; + BalsaHeaders headers; + SpdyHeaderBlock actual_header_block; + headers.AppendHeader("key1", "value1"); + headers.SetResponseFirstlineFromStringPieces("HTTP/1.1", "200", "OK"); + + interface_->SendSynReply(stream_id, headers); + + ASSERT_EQ(1u, connection_->output_list()->size()); + std::list<DataFrame*>::const_iterator i = connection_->output_list()->begin(); + DataFrame* df = *i++; + + { + InSequence s; + EXPECT_CALL(*spdy_framer_visitor_, OnSynReply(stream_id, false, _)) + .WillOnce(SaveArg<2>(&actual_header_block)); + } + + spdy_framer_->ProcessInput(df->data, df->size); + ASSERT_EQ(1, spdy_framer_->frames_received()); + ASSERT_EQ(3u, actual_header_block.size()); + ASSERT_EQ("200 OK", actual_header_block["status"]); + ASSERT_EQ("HTTP/1.1", actual_header_block["version"]); + ASSERT_EQ("value1", actual_header_block["key1"]); +} + +TEST_F(FlipSpdySMTest, SendDataFrame) { + uint32 stream_id = 133; + SpdyDataFlags flags = DATA_FLAG_NONE; + const char* actual_data; + size_t actual_size; + + interface_->SendDataFrame(stream_id, "hello", 5, flags, true); + + ASSERT_EQ(1u, connection_->output_list()->size()); + std::list<DataFrame*>::const_iterator i = connection_->output_list()->begin(); + DataFrame* df = *i++; + + { + InSequence s; + EXPECT_CALL(*spdy_framer_visitor_, + OnStreamFrameData(stream_id, _, _, false)) + .WillOnce(DoAll(SaveArg<1>(&actual_data), + SaveArg<2>(&actual_size))); + } + + spdy_framer_->ProcessInput(df->data, df->size); + ASSERT_EQ(1, spdy_framer_->frames_received()); + ASSERT_EQ("hello", StringPiece(actual_data, actual_size)); +} + +TEST_F(FlipSpdySMTest, SendLongDataFrame) { + uint32 stream_id = 133; + SpdyDataFlags flags = DATA_FLAG_NONE; + const char* actual_data; + size_t actual_size; + + std::string data = + std::string(kSpdySegmentSize, 'a') + + std::string(kSpdySegmentSize, 'b') + + "c"; + interface_->SendDataFrame(stream_id, data.data(), data.size(), flags, true); + + { + InSequence s; + EXPECT_CALL(*spdy_framer_visitor_, + OnStreamFrameData(stream_id, _, _, false)).Times(3) + .WillRepeatedly(DoAll(SaveArg<1>(&actual_data), + SaveArg<2>(&actual_size))); + } + + ASSERT_EQ(3u, connection_->output_list()->size()); + std::list<DataFrame*>::const_iterator i = connection_->output_list()->begin(); + DataFrame* df = *i++; + spdy_framer_->ProcessInput(df->data, df->size); + ASSERT_EQ(std::string(kSpdySegmentSize, 'a'), + StringPiece(actual_data, actual_size)); + + df = *i++; + spdy_framer_->ProcessInput(df->data, df->size); + ASSERT_EQ(std::string(kSpdySegmentSize, 'b'), + StringPiece(actual_data, actual_size)); + + df = *i++; + spdy_framer_->ProcessInput(df->data, df->size); + ASSERT_EQ("c", StringPiece(actual_data, actual_size)); +} + +TEST_F(FlipSpdy2SMTest, SendEOF) { + uint32 stream_id = 82; + // SPDY2 data frame + char empty_data_frame[] = {'\0', '\0', '\0', '\x52', '\x1', '\0', '\0', '\0'}; + MemCacheIter mci; + mci.stream_id = stream_id; + + { + BalsaHeaders headers; + std::string filename = "foobar"; + memory_cache_->InsertFile(&headers, filename, ""); + mci.file_data = memory_cache_->GetFileData(filename); + } + + interface_->AddToOutputOrder(mci); + ASSERT_TRUE(HasStream(stream_id)); + interface_->SendEOF(stream_id); + ASSERT_FALSE(HasStream(stream_id)); + + ASSERT_EQ(1u, connection_->output_list()->size()); + std::list<DataFrame*>::const_iterator i = connection_->output_list()->begin(); + DataFrame* df = *i++; + ASSERT_EQ(StringPiece(empty_data_frame, sizeof(empty_data_frame)), + StringPiece(df->data, df->size)); +} + +TEST_F(FlipSpdy2SMTest, SendEmptyDataFrame) { + uint32 stream_id = 133; + SpdyDataFlags flags = DATA_FLAG_NONE; + // SPDY2 data frame + char expected[] = {'\0', '\0', '\0', '\x85', '\0', '\0', '\0', '\0'}; + + interface_->SendDataFrame(stream_id, "hello", 0, flags, true); + + ASSERT_EQ(1u, connection_->output_list()->size()); + std::list<DataFrame*>::const_iterator i = connection_->output_list()->begin(); + DataFrame* df = *i++; + + ASSERT_EQ(StringPiece(expected, sizeof(expected)), + StringPiece(df->data, df->size)); +} + +TEST_F(FlipSpdy2SMTestNonProxy, OnSynStream) { + BufferedSpdyFramerVisitorInterface* visitor = interface_.get(); + uint32 stream_id = 82; + SpdyHeaderBlock spdy_headers; + spdy_headers["url"] = "http://www.example.com/path"; + spdy_headers["method"] = "GET"; + spdy_headers["scheme"] = "http"; + + { + BalsaHeaders headers; + memory_cache_->InsertFile(&headers, "GET_/path", ""); + } + visitor->OnSynStream(stream_id, 0, 0, 0, true, true, spdy_headers); + ASSERT_TRUE(HasStream(stream_id)); +} + +TEST_F(FlipSpdy2SMTestNonProxy, NewStream) { + uint32 stream_id = 13; + std::string filename = "foobar"; + + { + BalsaHeaders headers; + memory_cache_->InsertFile(&headers, filename, ""); + } + + interface_->NewStream(stream_id, 0, filename); + ASSERT_TRUE(HasStream(stream_id)); +} + +TEST_F(FlipSpdy2SMTestNonProxy, NewStreamError) { + uint32 stream_id = 82; + SpdyHeaderBlock actual_header_block; + const char* actual_data; + size_t actual_size; + testing::MockFunction<void(int)> checkpoint; + + interface_->NewStream(stream_id, 0, "nonexistingfile"); + + ASSERT_EQ(2u, connection_->output_list()->size()); + + { + InSequence s; + EXPECT_CALL(*spdy_framer_visitor_, + OnSynReply(stream_id, false, _)) + .WillOnce(SaveArg<2>(&actual_header_block)); + EXPECT_CALL(checkpoint, Call(0)); + EXPECT_CALL(*spdy_framer_visitor_, + OnStreamFrameData(stream_id, _, _, false)).Times(1) + .WillOnce(DoAll(SaveArg<1>(&actual_data), + SaveArg<2>(&actual_size))); + EXPECT_CALL(*spdy_framer_visitor_, + OnStreamFrameData(stream_id, NULL, 0, true)).Times(1); + } + + std::list<DataFrame*>::const_iterator i = connection_->output_list()->begin(); + DataFrame* df = *i++; + spdy_framer_->ProcessInput(df->data, df->size); + checkpoint.Call(0); + df = *i++; + spdy_framer_->ProcessInput(df->data, df->size); + + ASSERT_EQ(2, spdy_framer_->frames_received()); + ASSERT_EQ(2u, actual_header_block.size()); + ASSERT_EQ("404 Not Found", actual_header_block["status"]); + ASSERT_EQ("HTTP/1.1", actual_header_block["version"]); + ASSERT_EQ("wtf?", StringPiece(actual_data, actual_size)); +} + +} // namespace + +} // namespace net diff --git a/chromium/net/tools/gdig/gdig.cc b/chromium/net/tools/gdig/gdig.cc index 0dec8d7aaa0..54f05091c01 100644 --- a/chromium/net/tools/gdig/gdig.cc +++ b/chromium/net/tools/gdig/gdig.cc @@ -116,7 +116,7 @@ typedef std::vector<ReplayLogEntry> ReplayLog; // The file should be sorted by timestamp in ascending time. bool LoadReplayLog(const base::FilePath& file_path, ReplayLog* replay_log) { std::string original_replay_log_contents; - if (!file_util::ReadFileToString(file_path, &original_replay_log_contents)) { + if (!base::ReadFileToString(file_path, &original_replay_log_contents)) { fprintf(stderr, "Unable to open replay file %s\n", file_path.MaybeAsASCII().c_str()); return false; @@ -460,7 +460,11 @@ void GDig::ReplayNextEntry() { ++active_resolves_; ++replay_log_index_; int ret = resolver_->Resolve( - info, addrlist, callback, NULL, + info, + DEFAULT_PRIORITY, + addrlist, + callback, + NULL, BoundNetLog::Make(log_.get(), net::NetLog::SOURCE_NONE)); if (ret != ERR_IO_PENDING) callback.Run(ret); diff --git a/chromium/net/tools/quic/end_to_end_test.cc b/chromium/net/tools/quic/end_to_end_test.cc index 3903df965ee..7b61be6d149 100644 --- a/chromium/net/tools/quic/end_to_end_test.cc +++ b/chromium/net/tools/quic/end_to_end_test.cc @@ -61,13 +61,18 @@ void GenerateBody(string* body, int length) { // Simple wrapper class to run server in a thread. class ServerThread : public base::SimpleThread { public: - explicit ServerThread(IPEndPoint address, const QuicConfig& config) + ServerThread(IPEndPoint address, + const QuicConfig& config, + bool strike_register_no_startup_period) : SimpleThread("server_thread"), listening_(true, false), quit_(true, false), server_(config), address_(address), port_(0) { + if (strike_register_no_startup_period) { + server_.SetStrikeRegisterNoStartupPeriod(); + } } virtual ~ServerThread() { } @@ -116,7 +121,8 @@ class EndToEndTest : public ::testing::TestWithParam<QuicVersion> { protected: EndToEndTest() : server_hostname_("example.com"), - server_started_(false) { + server_started_(false), + strike_register_no_startup_period_(false) { net::IPAddressNumber ip; CHECK(net::ParseIPLiteralToNumber("127.0.0.1", &ip)); server_address_ = IPEndPoint(ip, 0); @@ -154,7 +160,8 @@ class EndToEndTest : public ::testing::TestWithParam<QuicVersion> { } void StartServer() { - server_thread_.reset(new ServerThread(server_address_, server_config_)); + server_thread_.reset(new ServerThread(server_address_, server_config_, + strike_register_no_startup_period_)); server_thread_->Start(); server_thread_->listening()->Wait(); server_address_ = IPEndPoint(server_address_.address(), @@ -210,6 +217,7 @@ class EndToEndTest : public ::testing::TestWithParam<QuicVersion> { QuicConfig client_config_; QuicConfig server_config_; QuicVersion version_; + bool strike_register_no_startup_period_; }; // Run all end to end tests with all supported versions. @@ -415,7 +423,29 @@ TEST_P(EndToEndTest, PostMissingBytes) { EXPECT_EQ(500u, client_->response_headers()->parsed_response_code()); } -TEST_P(EndToEndTest, LargePost) { +TEST_P(EndToEndTest, LargePostNoPacketLoss) { + // TODO(rtenneti): Delete this when NSS is supported. + if (!Aes128Gcm12Encrypter::IsSupported()) { + LOG(INFO) << "AES GCM not supported. Test skipped."; + return; + } + + ASSERT_TRUE(Initialize()); + + client_->client()->WaitForCryptoHandshakeConfirmed(); + + // 1 Mb body. + string body; + GenerateBody(&body, 1024 * 1024); + + HTTPMessage request(HttpConstants::HTTP_1_1, + HttpConstants::POST, "/foo"); + request.AddBody(body, true); + + EXPECT_EQ(kFooResponseBody, client_->SendCustomSynchronousRequest(request)); +} + +TEST_P(EndToEndTest, LargePostWithPacketLoss) { // TODO(rtenneti): Delete this when NSS is supported. if (!Aes128Gcm12Encrypter::IsSupported()) { LOG(INFO) << "AES GCM not supported. Test skipped."; @@ -431,8 +461,9 @@ TEST_P(EndToEndTest, LargePost) { client_->client()->WaitForCryptoHandshakeConfirmed(); // FLAGS_fake_packet_loss_percentage = 30; + // 10 Kb body. string body; - GenerateBody(&body, 10240); + GenerateBody(&body, 1024 * 10); HTTPMessage request(HttpConstants::HTTP_1_1, HttpConstants::POST, "/foo"); @@ -441,6 +472,44 @@ TEST_P(EndToEndTest, LargePost) { EXPECT_EQ(kFooResponseBody, client_->SendCustomSynchronousRequest(request)); } +TEST_P(EndToEndTest, LargePostZeroRTTFailure) { + // Have the server accept 0-RTT without waiting a startup period. + strike_register_no_startup_period_ = true; + + // Send a request and then disconnect. This prepares the client to attempt + // a 0-RTT handshake for the next request. + ASSERT_TRUE(Initialize()); + + string body; + GenerateBody(&body, 20480); + + HTTPMessage request(HttpConstants::HTTP_1_1, + HttpConstants::POST, "/foo"); + request.AddBody(body, true); + + EXPECT_EQ(kFooResponseBody, client_->SendCustomSynchronousRequest(request)); + EXPECT_EQ(2, client_->client()->session()->GetNumSentClientHellos()); + + client_->Disconnect(); + + // The 0-RTT handshake should succeed. + client_->Connect(); + ASSERT_TRUE(client_->client()->connected()); + EXPECT_EQ(kFooResponseBody, client_->SendCustomSynchronousRequest(request)); + EXPECT_EQ(1, client_->client()->session()->GetNumSentClientHellos()); + + client_->Disconnect(); + + // Restart the server so that the 0-RTT handshake will take 1 RTT. + StopServer(); + StartServer(); + + client_->Connect(); + ASSERT_TRUE(client_->client()->connected()); + EXPECT_EQ(kFooResponseBody, client_->SendCustomSynchronousRequest(request)); + EXPECT_EQ(2, client_->client()->session()->GetNumSentClientHellos()); +} + // TODO(ianswett): Enable once b/9295090 is fixed. TEST_P(EndToEndTest, DISABLED_LargePostFEC) { // FLAGS_fake_packet_loss_percentage = 30; @@ -517,7 +586,6 @@ TEST_P(EndToEndTest, DISABLED_MultipleTermination) { } ASSERT_TRUE(Initialize()); - scoped_ptr<QuicTestClient> client2(CreateQuicClient()); HTTPMessage request(HttpConstants::HTTP_1_1, HttpConstants::POST, "/foo"); @@ -597,6 +665,28 @@ TEST_P(EndToEndTest, ResetConnection) { EXPECT_EQ(200u, client_->response_headers()->parsed_response_code()); } +TEST_P(EndToEndTest, MaxStreamsUberTest) { + // FLAGS_fake_packet_loss_percentage = 1; + ASSERT_TRUE(Initialize()); + string large_body; + GenerateBody(&large_body, 10240); + int max_streams = 100; + + AddToCache("GET", "/large_response", "HTTP/1.1", "200", "OK", large_body);; + + client_->client()->WaitForCryptoHandshakeConfirmed(); + // FLAGS_fake_packet_loss_percentage = 10; + + for (int i = 0; i < max_streams; ++i) { + EXPECT_LT(0, client_->SendRequest("/large_response")); + } + + // WaitForEvents waits 50ms and returns true if there are outstanding + // requests. + while (client_->client()->WaitForEvents() == true) { + } +} + class WrongAddressWriter : public QuicPacketWriter { public: explicit WrongAddressWriter(int fd) : fd_(fd) { diff --git a/chromium/net/tools/quic/quic_client.h b/chromium/net/tools/quic/quic_client.h index ca20a8d2158..5e8960124d1 100644 --- a/chromium/net/tools/quic/quic_client.h +++ b/chromium/net/tools/quic/quic_client.h @@ -102,8 +102,6 @@ class QuicClient : public EpollCallbackInterface { bool connected() const; - int packets_dropped() { return packets_dropped_; } - void set_bind_to_address(IPAddressNumber address) { bind_to_address_ = address; } @@ -112,8 +110,6 @@ class QuicClient : public EpollCallbackInterface { void set_local_port(int local_port) { local_port_ = local_port; } - int local_port() { return local_port_; } - const IPEndPoint& server_address() const { return server_address_; } const IPEndPoint& client_address() const { return client_address_; } diff --git a/chromium/net/tools/quic/quic_client_session.cc b/chromium/net/tools/quic/quic_client_session.cc index c41df6702b5..f993908b624 100644 --- a/chromium/net/tools/quic/quic_client_session.cc +++ b/chromium/net/tools/quic/quic_client_session.cc @@ -55,6 +55,10 @@ bool QuicClientSession::CryptoConnect() { return crypto_stream_.CryptoConnect(); } +int QuicClientSession::GetNumSentClientHellos() const { + return crypto_stream_.num_sent_client_hellos(); +} + ReliableQuicStream* QuicClientSession::CreateIncomingReliableStream( QuicStreamId id) { DLOG(ERROR) << "Server push not supported"; diff --git a/chromium/net/tools/quic/quic_client_session.h b/chromium/net/tools/quic/quic_client_session.h index f51aeeaff1b..a73d721fef6 100644 --- a/chromium/net/tools/quic/quic_client_session.h +++ b/chromium/net/tools/quic/quic_client_session.h @@ -39,6 +39,11 @@ class QuicClientSession : public QuicSession { // handshake is started successfully. bool CryptoConnect(); + // Returns the number of client hello messages that have been sent on the + // crypto stream. If the handshake has completed then this is one greater + // than the number of round-trips needed for the handshake. + int GetNumSentClientHellos() const; + protected: // QuicSession methods: virtual ReliableQuicStream* CreateIncomingReliableStream( diff --git a/chromium/net/tools/quic/quic_dispatcher.cc b/chromium/net/tools/quic/quic_dispatcher.cc index 68691f7dc58..5253764fb20 100644 --- a/chromium/net/tools/quic/quic_dispatcher.cc +++ b/chromium/net/tools/quic/quic_dispatcher.cc @@ -59,7 +59,7 @@ int QuicDispatcher::WritePacket(const char* buffer, size_t buf_len, QuicBlockedWriterInterface* writer, int* error) { if (write_blocked_) { - write_blocked_list_.AddBlockedObject(writer); + write_blocked_list_.insert(make_pair(writer, true)); *error = EAGAIN; return -1; } @@ -68,7 +68,7 @@ int QuicDispatcher::WritePacket(const char* buffer, size_t buf_len, self_address, peer_address, error); if (rc == -1 && (*error == EWOULDBLOCK || *error == EAGAIN)) { - write_blocked_list_.AddBlockedObject(writer); + write_blocked_list_.insert(make_pair(writer, true)); write_blocked_ = true; } return rc; @@ -114,7 +114,7 @@ void QuicDispatcher::ProcessPacket(const IPEndPoint& server_address, void QuicDispatcher::CleanUpSession(SessionMap::iterator it) { QuicSession* session = it->second; - write_blocked_list_.RemoveBlockedObject(session->connection()); + write_blocked_list_.erase(session->connection()); time_wait_list_manager_->AddGuidToTimeWait(it->first, session->connection()->version()); session_map_.erase(it); @@ -129,13 +129,13 @@ bool QuicDispatcher::OnCanWrite() { write_blocked_ = false; // Give each writer one attempt to write. - int num_writers = write_blocked_list_.NumObjects(); + int num_writers = write_blocked_list_.size(); for (int i = 0; i < num_writers; ++i) { - if (write_blocked_list_.IsEmpty()) { + if (write_blocked_list_.empty()) { break; } - QuicBlockedWriterInterface* writer = - write_blocked_list_.GetNextBlockedObject(); + QuicBlockedWriterInterface* writer = write_blocked_list_.begin()->first; + write_blocked_list_.erase(write_blocked_list_.begin()); bool can_write_more = writer->OnCanWrite(); if (write_blocked_) { // We were unable to write. Wait for the next EPOLLOUT. @@ -146,12 +146,12 @@ bool QuicDispatcher::OnCanWrite() { // The socket is not blocked but the writer has ceded work. Add it to the // end of the list. if (can_write_more) { - write_blocked_list_.AddBlockedObject(writer); + write_blocked_list_.insert(make_pair(writer, true)); } } // We're not write blocked. Return true if there's more work to do. - return !write_blocked_list_.IsEmpty(); + return !write_blocked_list_.empty(); } void QuicDispatcher::Shutdown() { diff --git a/chromium/net/tools/quic/quic_dispatcher.h b/chromium/net/tools/quic/quic_dispatcher.h index bbf8d9b4941..aea76cba7c2 100644 --- a/chromium/net/tools/quic/quic_dispatcher.h +++ b/chromium/net/tools/quic/quic_dispatcher.h @@ -12,7 +12,7 @@ #include "base/containers/hash_tables.h" #include "net/base/ip_endpoint.h" -#include "net/quic/blocked_list.h" +#include "net/base/linked_hash_map.h" #include "net/quic/quic_blocked_writer_interface.h" #include "net/quic/quic_protocol.h" #include "net/tools/flip_server/epoll_server.h" @@ -48,7 +48,8 @@ class QuicDispatcherPeer; class DeleteSessionsAlarm; class QuicDispatcher : public QuicPacketWriter, public QuicSessionOwner { public: - typedef BlockedList<QuicBlockedWriterInterface*> WriteBlockedList; + // Ideally we'd have a linked_hash_set: the boolean is unused. + typedef linked_hash_map<QuicBlockedWriterInterface*, bool> WriteBlockedList; // Due to the way delete_sessions_closure_ is registered, the Dispatcher // must live until epoll_server Shutdown. @@ -82,7 +83,6 @@ class QuicDispatcher : public QuicPacketWriter, public QuicSessionOwner { // Ensure that the closed connection is cleaned up asynchronously. virtual void OnConnectionClose(QuicGuid guid, QuicErrorCode error) OVERRIDE; - int fd() { return fd_; } void set_fd(int fd) { fd_ = fd; } typedef base::hash_map<QuicGuid, QuicSession*> SessionMap; diff --git a/chromium/net/tools/quic/quic_dispatcher_test.cc b/chromium/net/tools/quic/quic_dispatcher_test.cc index 059d8334dc0..2e2c3097802 100644 --- a/chromium/net/tools/quic/quic_dispatcher_test.cc +++ b/chromium/net/tools/quic/quic_dispatcher_test.cc @@ -22,6 +22,7 @@ using base::StringPiece; using net::EpollServer; using net::test::MockSession; using net::tools::test::MockConnection; +using std::make_pair; using testing::_; using testing::DoAll; using testing::Invoke; @@ -237,7 +238,7 @@ TEST_F(QuicDispatcherTest, TimeWaitListManager) { ProcessPacket(addr, guid, "foo"); } -class WriteBlockedListTest : public QuicDispatcherTest { +class QuicWriteBlockedListTest : public QuicDispatcherTest { public: virtual void SetUp() { IPEndPoint addr(Loopback4(), 1); @@ -270,12 +271,12 @@ class WriteBlockedListTest : public QuicDispatcherTest { QuicDispatcher::WriteBlockedList* blocked_list_; }; -TEST_F(WriteBlockedListTest, BasicOnCanWrite) { +TEST_F(QuicWriteBlockedListTest, BasicOnCanWrite) { // No OnCanWrite calls because no connections are blocked. dispatcher_.OnCanWrite(); // Register connection 1 for events, and make sure it's nofitied. - blocked_list_->AddBlockedObject(connection1()); + blocked_list_->insert(make_pair(connection1(), true)); EXPECT_CALL(*connection1(), OnCanWrite()); dispatcher_.OnCanWrite(); @@ -284,67 +285,67 @@ TEST_F(WriteBlockedListTest, BasicOnCanWrite) { EXPECT_FALSE(dispatcher_.OnCanWrite()); } -TEST_F(WriteBlockedListTest, OnCanWriteOrder) { +TEST_F(QuicWriteBlockedListTest, OnCanWriteOrder) { // Make sure we handle events in order. InSequence s; - blocked_list_->AddBlockedObject(connection1()); - blocked_list_->AddBlockedObject(connection2()); + blocked_list_->insert(make_pair(connection1(), true)); + blocked_list_->insert(make_pair(connection2(), true)); EXPECT_CALL(*connection1(), OnCanWrite()); EXPECT_CALL(*connection2(), OnCanWrite()); dispatcher_.OnCanWrite(); // Check the other ordering. - blocked_list_->AddBlockedObject(connection2()); - blocked_list_->AddBlockedObject(connection1()); + blocked_list_->insert(make_pair(connection2(), true)); + blocked_list_->insert(make_pair(connection1(), true)); EXPECT_CALL(*connection2(), OnCanWrite()); EXPECT_CALL(*connection1(), OnCanWrite()); dispatcher_.OnCanWrite(); } -TEST_F(WriteBlockedListTest, OnCanWriteRemove) { +TEST_F(QuicWriteBlockedListTest, OnCanWriteRemove) { // Add and remove one connction. - blocked_list_->AddBlockedObject(connection1()); - blocked_list_->RemoveBlockedObject(connection1()); + blocked_list_->insert(make_pair(connection1(), true)); + blocked_list_->erase(connection1()); EXPECT_CALL(*connection1(), OnCanWrite()).Times(0); dispatcher_.OnCanWrite(); // Add and remove one connction and make sure it doesn't affect others. - blocked_list_->AddBlockedObject(connection1()); - blocked_list_->AddBlockedObject(connection2()); - blocked_list_->RemoveBlockedObject(connection1()); + blocked_list_->insert(make_pair(connection1(), true)); + blocked_list_->insert(make_pair(connection2(), true)); + blocked_list_->erase(connection1()); EXPECT_CALL(*connection2(), OnCanWrite()); dispatcher_.OnCanWrite(); // Add it, remove it, and add it back and make sure things are OK. - blocked_list_->AddBlockedObject(connection1()); - blocked_list_->RemoveBlockedObject(connection1()); - blocked_list_->AddBlockedObject(connection1()); + blocked_list_->insert(make_pair(connection1(), true)); + blocked_list_->erase(connection1()); + blocked_list_->insert(make_pair(connection1(), true)); EXPECT_CALL(*connection1(), OnCanWrite()).Times(1); dispatcher_.OnCanWrite(); } -TEST_F(WriteBlockedListTest, DoubleAdd) { +TEST_F(QuicWriteBlockedListTest, DoubleAdd) { // Make sure a double add does not necessitate a double remove. - blocked_list_->AddBlockedObject(connection1()); - blocked_list_->AddBlockedObject(connection1()); - blocked_list_->RemoveBlockedObject(connection1()); + blocked_list_->insert(make_pair(connection1(), true)); + blocked_list_->insert(make_pair(connection1(), true)); + blocked_list_->erase(connection1()); EXPECT_CALL(*connection1(), OnCanWrite()).Times(0); dispatcher_.OnCanWrite(); // Make sure a double add does not result in two OnCanWrite calls. - blocked_list_->AddBlockedObject(connection1()); - blocked_list_->AddBlockedObject(connection1()); + blocked_list_->insert(make_pair(connection1(), true)); + blocked_list_->insert(make_pair(connection1(), true)); EXPECT_CALL(*connection1(), OnCanWrite()).Times(1); dispatcher_.OnCanWrite(); } -TEST_F(WriteBlockedListTest, OnCanWriteHandleBlock) { +TEST_F(QuicWriteBlockedListTest, OnCanWriteHandleBlock) { // Finally make sure if we write block on a write call, we stop calling. InSequence s; - blocked_list_->AddBlockedObject(connection1()); - blocked_list_->AddBlockedObject(connection2()); + blocked_list_->insert(make_pair(connection1(), true)); + blocked_list_->insert(make_pair(connection2(), true)); EXPECT_CALL(*connection1(), OnCanWrite()).WillOnce( - Invoke(this, &WriteBlockedListTest::SetBlocked)); + Invoke(this, &QuicWriteBlockedListTest::SetBlocked)); EXPECT_CALL(*connection2(), OnCanWrite()).Times(0); dispatcher_.OnCanWrite(); @@ -353,12 +354,12 @@ TEST_F(WriteBlockedListTest, OnCanWriteHandleBlock) { dispatcher_.OnCanWrite(); } -TEST_F(WriteBlockedListTest, LimitedWrites) { +TEST_F(QuicWriteBlockedListTest, LimitedWrites) { // Make sure we call both writers. The first will register for more writing // but should not be immediately called due to limits. InSequence s; - blocked_list_->AddBlockedObject(connection1()); - blocked_list_->AddBlockedObject(connection2()); + blocked_list_->insert(make_pair(connection1(), true)); + blocked_list_->insert(make_pair(connection2(), true)); EXPECT_CALL(*connection1(), OnCanWrite()).WillOnce(Return(true)); EXPECT_CALL(*connection2(), OnCanWrite()).WillOnce(Return(false)); dispatcher_.OnCanWrite(); @@ -368,13 +369,13 @@ TEST_F(WriteBlockedListTest, LimitedWrites) { dispatcher_.OnCanWrite(); } -TEST_F(WriteBlockedListTest, TestWriteLimits) { +TEST_F(QuicWriteBlockedListTest, TestWriteLimits) { // Finally make sure if we write block on a write call, we stop calling. InSequence s; - blocked_list_->AddBlockedObject(connection1()); - blocked_list_->AddBlockedObject(connection2()); + blocked_list_->insert(make_pair(connection1(), true)); + blocked_list_->insert(make_pair(connection2(), true)); EXPECT_CALL(*connection1(), OnCanWrite()).WillOnce( - Invoke(this, &WriteBlockedListTest::SetBlocked)); + Invoke(this, &QuicWriteBlockedListTest::SetBlocked)); EXPECT_CALL(*connection2(), OnCanWrite()).Times(0); dispatcher_.OnCanWrite(); diff --git a/chromium/net/tools/quic/quic_epoll_clock_test.cc b/chromium/net/tools/quic/quic_epoll_clock_test.cc index c774d0adba2..1ed1d15c1d7 100644 --- a/chromium/net/tools/quic/quic_epoll_clock_test.cc +++ b/chromium/net/tools/quic/quic_epoll_clock_test.cc @@ -37,7 +37,7 @@ TEST(QuicEpollClockTest, NowInUsec) { clock.Now().Subtract(QuicTime::Zero()).ToMicroseconds()); } -TEST(QuicClockTest, WallNow) { +TEST(QuicEpollClockTest, WallNow) { MockEpollServer epoll_server; QuicEpollClock clock(&epoll_server); diff --git a/chromium/net/tools/quic/quic_epoll_connection_helper_test.cc b/chromium/net/tools/quic/quic_epoll_connection_helper_test.cc index 0bfaee27d07..636d98f7ee5 100644 --- a/chromium/net/tools/quic/quic_epoll_connection_helper_test.cc +++ b/chromium/net/tools/quic/quic_epoll_connection_helper_test.cc @@ -8,6 +8,7 @@ #include "net/quic/crypto/quic_decrypter.h" #include "net/quic/crypto/quic_encrypter.h" #include "net/quic/crypto/quic_random.h" +#include "net/quic/quic_connection.h" #include "net/quic/quic_framer.h" #include "net/quic/test_tools/quic_connection_peer.h" #include "net/quic/test_tools/quic_test_utils.h" @@ -15,13 +16,13 @@ #include "testing/gmock/include/gmock/gmock.h" #include "testing/gtest/include/gtest/gtest.h" -using net::test::GetMinStreamFrameSize; using net::test::FramerVisitorCapturingFrames; using net::test::MockSendAlgorithm; using net::test::QuicConnectionPeer; using net::test::MockConnectionVisitor; using net::tools::test::MockEpollServer; using testing::_; +using testing::AnyNumber; using testing::Return; namespace net { @@ -88,6 +89,12 @@ class QuicEpollConnectionHelperTest : public ::testing::Test { epoll_server_.set_timeout_in_us(-1); EXPECT_CALL(*send_algorithm_, TimeUntilSend(_, _, _, _)). WillRepeatedly(Return(QuicTime::Delta::Zero())); + EXPECT_CALL(*send_algorithm_, BandwidthEstimate()).WillRepeatedly(Return( + QuicBandwidth::FromKBitsPerSecond(100))); + EXPECT_CALL(*send_algorithm_, SmoothedRtt()).WillRepeatedly(Return( + QuicTime::Delta::FromMilliseconds(100))); + ON_CALL(*send_algorithm_, SentPacket(_, _, _, _, _)) + .WillByDefault(Return(true)); } QuicPacket* ConstructDataPacket(QuicPacketSequenceNumber number, @@ -127,16 +134,20 @@ TEST_F(QuicEpollConnectionHelperTest, DISABLED_TestRetransmission) { const char buffer[] = "foo"; const size_t packet_size = - GetPacketHeaderSize(PACKET_8BYTE_GUID, kIncludeVersion, - PACKET_6BYTE_SEQUENCE_NUMBER, NOT_IN_FEC_GROUP) + - GetMinStreamFrameSize(framer_.version()) + arraysize(buffer) - 1; + QuicPacketCreator::StreamFramePacketOverhead( + framer_.version(), PACKET_8BYTE_GUID, kIncludeVersion, + PACKET_1BYTE_SEQUENCE_NUMBER, NOT_IN_FEC_GROUP) + + arraysize(buffer) - 1; + EXPECT_CALL(*send_algorithm_, - SentPacket(_, 1, packet_size, NOT_RETRANSMISSION)); + SentPacket(_, 1, packet_size, NOT_RETRANSMISSION, _)); EXPECT_CALL(*send_algorithm_, AbandoningPacket(1, packet_size)); - connection_.SendStreamData(1, buffer, 0, false); + struct iovec iov = {const_cast<char*>(buffer), + static_cast<size_t>(3)}; + connection_.SendvStreamData(1, &iov, 1, 0, false); EXPECT_EQ(1u, helper_->header()->packet_sequence_number); EXPECT_CALL(*send_algorithm_, - SentPacket(_, 2, packet_size, IS_RETRANSMISSION)); + SentPacket(_, 2, packet_size, IS_RETRANSMISSION, _)); epoll_server_.AdvanceByAndCallCallbacks(kDefaultRetransmissionTimeMs * 1000); EXPECT_EQ(2u, helper_->header()->packet_sequence_number); @@ -145,7 +156,10 @@ TEST_F(QuicEpollConnectionHelperTest, DISABLED_TestRetransmission) { TEST_F(QuicEpollConnectionHelperTest, InitialTimeout) { EXPECT_TRUE(connection_.connected()); - EXPECT_CALL(*send_algorithm_, SentPacket(_, 1, _, NOT_RETRANSMISSION)); + EXPECT_CALL(*send_algorithm_, SentPacket(_, 1, _, NOT_RETRANSMISSION, + HAS_RETRANSMITTABLE_DATA)); + EXPECT_CALL(*send_algorithm_, RetransmissionDelay()).WillOnce( + Return(QuicTime::Delta::FromMicroseconds(1))); EXPECT_CALL(visitor_, ConnectionClose(QUIC_CONNECTION_TIMED_OUT, !kFromPeer)); epoll_server_.WaitForEventsAndExecuteCallbacks(); EXPECT_FALSE(connection_.connected()); @@ -162,7 +176,8 @@ TEST_F(QuicEpollConnectionHelperTest, TimeoutAfterSend) { EXPECT_EQ(5000, epoll_server_.NowInUsec()); // Send an ack so we don't set the retransmission alarm. - EXPECT_CALL(*send_algorithm_, SentPacket(_, 1, _, NOT_RETRANSMISSION)); + EXPECT_CALL(*send_algorithm_, + SentPacket(_, 1, _, NOT_RETRANSMISSION, NO_RETRANSMITTABLE_DATA)); connection_.SendAck(); // The original alarm will fire. We should not time out because we had a @@ -172,7 +187,10 @@ TEST_F(QuicEpollConnectionHelperTest, TimeoutAfterSend) { // This time, we should time out. EXPECT_CALL(visitor_, ConnectionClose(QUIC_CONNECTION_TIMED_OUT, !kFromPeer)); - EXPECT_CALL(*send_algorithm_, SentPacket(_, 2, _, NOT_RETRANSMISSION)); + EXPECT_CALL(*send_algorithm_, SentPacket(_, 2, _, NOT_RETRANSMISSION, + HAS_RETRANSMITTABLE_DATA)); + EXPECT_CALL(*send_algorithm_, RetransmissionDelay()).WillOnce( + Return(QuicTime::Delta::FromMicroseconds(1))); epoll_server_.WaitForEventsAndExecuteCallbacks(); EXPECT_EQ(kDefaultInitialTimeoutSecs * 1000000 + 5000, epoll_server_.NowInUsec()); @@ -186,17 +204,21 @@ TEST_F(QuicEpollConnectionHelperTest, SendSchedulerDelayThenSend) { QuicPacket* packet = ConstructDataPacket(1, 0); EXPECT_CALL( *send_algorithm_, TimeUntilSend(_, NOT_RETRANSMISSION, _, _)).WillOnce( - testing::Return(QuicTime::Delta::FromMicroseconds(1))); + Return(QuicTime::Delta::FromMicroseconds(1))); connection_.SendOrQueuePacket(ENCRYPTION_NONE, 1, packet, 0, - HAS_RETRANSMITTABLE_DATA); - EXPECT_CALL(*send_algorithm_, SentPacket(_, 1, _, NOT_RETRANSMISSION)); + HAS_RETRANSMITTABLE_DATA, + QuicConnection::NO_FORCE); + EXPECT_CALL(*send_algorithm_, SentPacket(_, 1, _, NOT_RETRANSMISSION, + _)); EXPECT_EQ(1u, connection_.NumQueuedPackets()); // Advance the clock to fire the alarm, and configure the scheduler // to permit the packet to be sent. - EXPECT_CALL(*send_algorithm_, TimeUntilSend(_, NOT_RETRANSMISSION, _, _)). - WillRepeatedly(testing::Return(QuicTime::Delta::Zero())); - EXPECT_CALL(visitor_, OnCanWrite()).WillOnce(testing::Return(true)); + EXPECT_CALL(*send_algorithm_, + TimeUntilSend(_, NOT_RETRANSMISSION, _, _)).WillRepeatedly( + Return(QuicTime::Delta::Zero())); + EXPECT_CALL(visitor_, OnCanWrite()).WillOnce(Return(true)); + EXPECT_CALL(visitor_, HasPendingHandshake()).Times(AnyNumber()); epoll_server_.AdvanceByAndCallCallbacks(1); EXPECT_EQ(0u, connection_.NumQueuedPackets()); } diff --git a/chromium/net/tools/quic/quic_in_memory_cache.cc b/chromium/net/tools/quic/quic_in_memory_cache.cc index b840d79370d..6ed7dd5f832 100644 --- a/chromium/net/tools/quic/quic_in_memory_cache.cc +++ b/chromium/net/tools/quic/quic_in_memory_cache.cc @@ -140,7 +140,7 @@ void QuicInMemoryCache::Initialize() { BalsaHeaders request_headers, response_headers; string file_contents; - file_util::ReadFileToString(file, &file_contents); + base::ReadFileToString(file, &file_contents); // Frame HTTP. CachingBalsaVisitor caching_visitor; diff --git a/chromium/net/tools/quic/quic_reliable_server_stream_test.cc b/chromium/net/tools/quic/quic_reliable_server_stream_test.cc index b946d94cfc3..53533a35cb1 100644 --- a/chromium/net/tools/quic/quic_reliable_server_stream_test.cc +++ b/chromium/net/tools/quic/quic_reliable_server_stream_test.cc @@ -58,7 +58,9 @@ class QuicReliableServerStreamTest : public ::testing::Test { stream_.reset(new QuicSpdyServerStream(3, &session_)); } - QuicConsumedData ValidateHeaders(StringPiece headers) { + QuicConsumedData ValidateHeaders(const struct iovec* iov) { + StringPiece headers = + StringPiece(static_cast<const char*>(iov[0].iov_base), iov[0].iov_len); headers_string_ = SpdyUtils::SerializeResponseHeaders( response_headers_); QuicSpdyDecompressor decompressor; @@ -119,13 +121,20 @@ class QuicReliableServerStreamTest : public ::testing::Test { string body_; }; -QuicConsumedData ConsumeAllData(QuicStreamId id, StringPiece data, - QuicStreamOffset offset, bool fin) { - return QuicConsumedData(data.size(), fin); +QuicConsumedData ConsumeAllData(QuicStreamId id, + const struct iovec* iov, + int iov_count, + QuicStreamOffset offset, + bool fin) { + ssize_t consumed_length = 0; + for (int i = 0; i < iov_count; ++i) { + consumed_length += iov[i].iov_len; + } + return QuicConsumedData(consumed_length, fin); } TEST_F(QuicReliableServerStreamTest, TestFraming) { - EXPECT_CALL(session_, WriteData(_, _, _, _)).Times(AnyNumber()). + EXPECT_CALL(session_, WritevData(_, _, _, _, _)).Times(AnyNumber()). WillRepeatedly(Invoke(ConsumeAllData)); EXPECT_EQ(headers_string_.size(), stream_->ProcessData( @@ -138,7 +147,7 @@ TEST_F(QuicReliableServerStreamTest, TestFraming) { } TEST_F(QuicReliableServerStreamTest, TestFramingOnePacket) { - EXPECT_CALL(session_, WriteData(_, _, _, _)).Times(AnyNumber()). + EXPECT_CALL(session_, WritevData(_, _, _, _, _)).Times(AnyNumber()). WillRepeatedly(Invoke(ConsumeAllData)); string message = headers_string_ + body_; @@ -156,7 +165,7 @@ TEST_F(QuicReliableServerStreamTest, TestFramingExtraData) { string large_body = "hello world!!!!!!"; // We'll automatically write out an error (headers + body) - EXPECT_CALL(session_, WriteData(_, _, _, _)).Times(2). + EXPECT_CALL(session_, WritevData(_, _, _, _, _)).Times(2). WillRepeatedly(Invoke(ConsumeAllData)); EXPECT_EQ(headers_string_.size(), stream_->ProcessData( @@ -183,11 +192,11 @@ TEST_F(QuicReliableServerStreamTest, TestSendResponse) { response_headers_.ReplaceOrAppendHeader("content-length", "3"); InSequence s; - EXPECT_CALL(session_, WriteData(_, _, _, _)).Times(1) + EXPECT_CALL(session_, WritevData(_, _, 1, _, _)).Times(1) .WillOnce(WithArgs<1>(Invoke( this, &QuicReliableServerStreamTest::ValidateHeaders))); - StringPiece kBody = "Yum"; - EXPECT_CALL(session_, WriteData(_, kBody, _, _)).Times(1). + + EXPECT_CALL(session_, WritevData(_, _, 1, _, _)).Times(1). WillOnce(Return(QuicConsumedData(3, true))); stream_->SendResponse(); @@ -201,11 +210,11 @@ TEST_F(QuicReliableServerStreamTest, TestSendErrorResponse) { response_headers_.ReplaceOrAppendHeader("content-length", "3"); InSequence s; - EXPECT_CALL(session_, WriteData(_, _, _, _)).Times(1) + EXPECT_CALL(session_, WritevData(_, _, 1, _, _)).Times(1) .WillOnce(WithArgs<1>(Invoke( this, &QuicReliableServerStreamTest::ValidateHeaders))); - StringPiece kBody = "bad"; - EXPECT_CALL(session_, WriteData(_, kBody, _, _)).Times(1). + + EXPECT_CALL(session_, WritevData(_, _, 1, _, _)).Times(1). WillOnce(Return(QuicConsumedData(3, true))); stream_->SendErrorResponse(); diff --git a/chromium/net/tools/quic/quic_server.cc b/chromium/net/tools/quic/quic_server.cc index d18525f634f..9eae4ab7566 100644 --- a/chromium/net/tools/quic/quic_server.cc +++ b/chromium/net/tools/quic/quic_server.cc @@ -36,6 +36,7 @@ namespace tools { QuicServer::QuicServer() : port_(0), + fd_(-1), packets_dropped_(0), overflow_supported_(false), use_recvmmsg_(false), @@ -47,6 +48,7 @@ QuicServer::QuicServer() QuicServer::QuicServer(const QuicConfig& config) : port_(0), + fd_(-1), packets_dropped_(0), overflow_supported_(false), use_recvmmsg_(false), @@ -154,6 +156,9 @@ void QuicServer::Shutdown() { // Before we shut down the epoll server, give all active sessions a chance to // notify clients that they're closing. dispatcher_->Shutdown(); + + close(fd_); + fd_ = -1; } void QuicServer::OnEvent(int fd, EpollEvent* event) { diff --git a/chromium/net/tools/quic/quic_server.h b/chromium/net/tools/quic/quic_server.h index 142d1d1572c..90863ce6400 100644 --- a/chromium/net/tools/quic/quic_server.h +++ b/chromium/net/tools/quic/quic_server.h @@ -18,8 +18,6 @@ namespace net { -class QuicCryptoServerConfig; - namespace tools { class QuicDispatcher; @@ -67,6 +65,10 @@ class QuicServer : public EpollCallbackInterface { const IPEndPoint& server_address, const IPEndPoint& client_address); + void SetStrikeRegisterNoStartupPeriod() { + crypto_config_.set_strike_register_no_startup_period(); + } + bool overflow_supported() { return overflow_supported_; } int packets_dropped() { return packets_dropped_; } diff --git a/chromium/net/tools/quic/quic_server_session.h b/chromium/net/tools/quic/quic_server_session.h index 604b9fc0c42..2f031373bda 100644 --- a/chromium/net/tools/quic/quic_server_session.h +++ b/chromium/net/tools/quic/quic_server_session.h @@ -25,6 +25,10 @@ class ReliableQuicStream; namespace tools { +namespace test { +class QuicServerSessionPeer; +} // namespace test + // An interface from the session to the entity owning the session. // This lets the session notify its owner (the Dispatcher) when the connection // is closed. @@ -66,6 +70,8 @@ class QuicServerSession : public QuicSession { const QuicCryptoServerConfig& crypto_config); private: + friend class test::QuicServerSessionPeer; + scoped_ptr<QuicCryptoServerStream> crypto_stream_; QuicSessionOwner* owner_; diff --git a/chromium/net/tools/quic/quic_server_session_test.cc b/chromium/net/tools/quic/quic_server_session_test.cc new file mode 100644 index 00000000000..a6f94abd7ba --- /dev/null +++ b/chromium/net/tools/quic/quic_server_session_test.cc @@ -0,0 +1,255 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/tools/quic/quic_server_session.h" + + +#include "net/quic/crypto/crypto_server_config.h" +#include "net/quic/crypto/quic_random.h" +#include "net/quic/quic_connection.h" +#include "net/quic/test_tools/quic_connection_peer.h" +#include "net/quic/test_tools/quic_test_utils.h" +#include "net/quic/test_tools/reliable_quic_stream_peer.h" +#include "net/tools/flip_server/epoll_server.h" +#include "net/tools/quic/quic_spdy_server_stream.h" +#include "net/tools/quic/test_tools/quic_test_utils.h" +#include "testing/gmock/include/gmock/gmock.h" +#include "testing/gtest/include/gtest/gtest.h" + +using __gnu_cxx::vector; +using net::test::MockConnection; +using net::test::QuicConnectionPeer; +using net::test::ReliableQuicStreamPeer; +using testing::_; +using testing::StrictMock; + +namespace net { +namespace tools { +namespace test { + +class QuicServerSessionPeer { + public: + static ReliableQuicStream* GetIncomingReliableStream( + QuicServerSession* s, QuicStreamId id) { + return s->GetIncomingReliableStream(id); + } + static ReliableQuicStream* GetStream(QuicServerSession* s, QuicStreamId id) { + return s->GetStream(id); + } +}; + +class CloseOnDataStream : public ReliableQuicStream { + public: + CloseOnDataStream(QuicStreamId id, QuicSession* session) + : ReliableQuicStream(id, session) { + } + + virtual bool OnStreamFrame(const QuicStreamFrame& frame) OVERRIDE { + session()->MarkDecompressionBlocked(1, id()); + session()->CloseStream(id()); + return true; + } + + virtual uint32 ProcessData(const char* data, uint32 data_len) OVERRIDE { + return 0; + } +}; + +class TestQuicQuicServerSession : public QuicServerSession { + public: + TestQuicQuicServerSession(const QuicConfig& config, + QuicConnection* connection, + QuicSessionOwner* owner) + : QuicServerSession(config, connection, owner), + close_stream_on_data_(false) { + } + + virtual ReliableQuicStream* CreateIncomingReliableStream( + QuicStreamId id) OVERRIDE { + if (!ShouldCreateIncomingReliableStream(id)) { + return NULL; + } + if (close_stream_on_data_) { + return new CloseOnDataStream(id, this); + } else { + return new QuicSpdyServerStream(id, this); + } + } + + void CloseStreamOnData() { + close_stream_on_data_ = true; + } + + private: + bool close_stream_on_data_; +}; + +namespace { + +class QuicServerSessionTest : public ::testing::Test { + protected: + QuicServerSessionTest() + : guid_(1), + crypto_config_(QuicCryptoServerConfig::TESTING, + QuicRandom::GetInstance()) { + config_.SetDefaults(); + config_.set_max_streams_per_connection(3, 3); + + connection_ = new MockConnection(guid_, IPEndPoint(), 0, &eps_, true); + session_.reset(new TestQuicQuicServerSession( + config_, connection_, &owner_)); + session_->InitializeSession(crypto_config_); + visitor_ = QuicConnectionPeer::GetVisitor(connection_); + } + + void MarkHeadersReadForStream(QuicStreamId id) { + ReliableQuicStream* stream = QuicServerSessionPeer::GetStream( + session_.get(), id); + ASSERT_TRUE(stream != NULL); + ReliableQuicStreamPeer::SetHeadersDecompressed(stream, true); + } + + QuicGuid guid_; + EpollServer eps_; + StrictMock<MockQuicSessionOwner> owner_; + MockConnection* connection_; + QuicConfig config_; + QuicCryptoServerConfig crypto_config_; + scoped_ptr<TestQuicQuicServerSession> session_; + QuicConnectionVisitorInterface* visitor_; +}; + +TEST_F(QuicServerSessionTest, CloseStreamDueToReset) { + // Open a stream, then reset it. + // Send two bytes of payload to open it. + QuicPacketHeader header; + header.public_header.guid = guid_; + header.public_header.reset_flag = false; + header.public_header.version_flag = false; + QuicStreamFrame data1(3, false, 0, "HT"); + vector<QuicStreamFrame> frames; + frames.push_back(data1); + EXPECT_TRUE(visitor_->OnStreamFrames(frames)); + EXPECT_EQ(1u, session_->GetNumOpenStreams()); + + // Pretend we got full headers, so we won't trigger the 'unrecoverable + // compression context' state. + MarkHeadersReadForStream(3); + + // Send a reset. + QuicRstStreamFrame rst1(3, QUIC_STREAM_NO_ERROR); + visitor_->OnRstStream(rst1); + EXPECT_EQ(0u, session_->GetNumOpenStreams()); + + // Send the same two bytes of payload in a new packet. + EXPECT_TRUE(visitor_->OnStreamFrames(frames)); + + // The stream should not be re-opened. + EXPECT_EQ(0u, session_->GetNumOpenStreams()); +} + +TEST_F(QuicServerSessionTest, NeverOpenStreamDueToReset) { + // Send a reset. + QuicRstStreamFrame rst1(3, QUIC_STREAM_NO_ERROR); + visitor_->OnRstStream(rst1); + EXPECT_EQ(0u, session_->GetNumOpenStreams()); + + // Send two bytes of payload. + QuicPacketHeader header; + header.public_header.guid = guid_; + header.public_header.reset_flag = false; + header.public_header.version_flag = false; + QuicStreamFrame data1(3, false, 0, "HT"); + vector<QuicStreamFrame> frames; + frames.push_back(data1); + + // When we get data for the closed stream, it implies the far side has + // compressed some headers. As a result we're going to bail due to + // unrecoverable compression context state. + EXPECT_CALL(*connection_, SendConnectionClose( + QUIC_STREAM_RST_BEFORE_HEADERS_DECOMPRESSED)); + EXPECT_FALSE(visitor_->OnStreamFrames(frames)); + + // The stream should never be opened, now that the reset is received. + EXPECT_EQ(0u, session_->GetNumOpenStreams()); +} + +TEST_F(QuicServerSessionTest, GoOverPrematureClosedStreamLimit) { + QuicPacketHeader header; + header.public_header.guid = guid_; + header.public_header.reset_flag = false; + header.public_header.version_flag = false; + QuicStreamFrame data1(3, false, 0, "H"); + vector<QuicStreamFrame> frames; + frames.push_back(data1); + + // Set up the stream such that it's open in OnPacket, but closes half way + // through while on the decompression blocked list. + session_->CloseStreamOnData(); + + EXPECT_CALL(*connection_, SendConnectionClose( + QUIC_STREAM_RST_BEFORE_HEADERS_DECOMPRESSED)); + EXPECT_FALSE(visitor_->OnStreamFrames(frames)); +} + +TEST_F(QuicServerSessionTest, AcceptClosedStream) { + QuicPacketHeader header; + header.public_header.guid = guid_; + header.public_header.reset_flag = false; + header.public_header.version_flag = false; + vector<QuicStreamFrame> frames; + // Send (empty) compressed headers followed by two bytes of data. + frames.push_back(QuicStreamFrame(3, false, 0, "\1\0\0\0\0\0\0\0HT")); + frames.push_back(QuicStreamFrame(5, false, 0, "\2\0\0\0\0\0\0\0HT")); + EXPECT_TRUE(visitor_->OnStreamFrames(frames)); + + // Pretend we got full headers, so we won't trigger the 'unercoverable + // compression context' state. + MarkHeadersReadForStream(3); + + // Send a reset. + QuicRstStreamFrame rst(3, QUIC_STREAM_NO_ERROR); + visitor_->OnRstStream(rst); + + // If we were tracking, we'd probably want to reject this because it's data + // past the reset point of stream 3. As it's a closed stream we just drop the + // data on the floor, but accept the packet because it has data for stream 5. + frames.clear(); + frames.push_back(QuicStreamFrame(3, false, 2, "TP")); + frames.push_back(QuicStreamFrame(5, false, 2, "TP")); + EXPECT_TRUE(visitor_->OnStreamFrames(frames)); +} + +TEST_F(QuicServerSessionTest, MaxNumConnections) { + EXPECT_EQ(0u, session_->GetNumOpenStreams()); + EXPECT_TRUE( + QuicServerSessionPeer::GetIncomingReliableStream(session_.get(), 3)); + EXPECT_TRUE( + QuicServerSessionPeer::GetIncomingReliableStream(session_.get(), 5)); + EXPECT_TRUE( + QuicServerSessionPeer::GetIncomingReliableStream(session_.get(), 7)); + EXPECT_FALSE( + QuicServerSessionPeer::GetIncomingReliableStream(session_.get(), 9)); +} + +TEST_F(QuicServerSessionTest, MaxNumConnectionsImplicit) { + EXPECT_EQ(0u, session_->GetNumOpenStreams()); + EXPECT_TRUE( + QuicServerSessionPeer::GetIncomingReliableStream(session_.get(), 3)); + // Implicitly opens two more streams before 9. + EXPECT_FALSE( + QuicServerSessionPeer::GetIncomingReliableStream(session_.get(), 9)); +} + +TEST_F(QuicServerSessionTest, GetEvenIncomingError) { + // Incoming streams on the server session must be odd. + EXPECT_EQ(NULL, + QuicServerSessionPeer::GetIncomingReliableStream( + session_.get(), 2)); +} + +} // namespace +} // namespace test +} // namespace tools +} // namespace net diff --git a/chromium/net/tools/quic/quic_spdy_client_stream.cc b/chromium/net/tools/quic/quic_spdy_client_stream.cc index 62949534b3d..368f93eb855 100644 --- a/chromium/net/tools/quic/quic_spdy_client_stream.cc +++ b/chromium/net/tools/quic/quic_spdy_client_stream.cc @@ -62,8 +62,13 @@ ssize_t QuicSpdyClientStream::SendRequest(const BalsaHeaders& headers, SpdyHeaderBlock header_block = SpdyUtils::RequestHeadersToSpdyHeaders(headers); - string headers_string = - session()->compressor()->CompressHeaders(header_block); + string headers_string; + if (session()->connection()->version() >= QUIC_VERSION_9) { + headers_string = session()->compressor()->CompressHeadersWithPriority( + priority(), header_block); + } else { + headers_string = session()->compressor()->CompressHeaders(header_block); + } bool has_body = !body.empty(); diff --git a/chromium/net/tools/quic/quic_spdy_client_stream.h b/chromium/net/tools/quic/quic_spdy_client_stream.h index ec4d25747f7..5d32b30d3cb 100644 --- a/chromium/net/tools/quic/quic_spdy_client_stream.h +++ b/chromium/net/tools/quic/quic_spdy_client_stream.h @@ -34,6 +34,10 @@ class QuicSpdyClientStream : public QuicReliableClientStream { base::StringPiece body, bool fin) OVERRIDE; + // While the server's set_priority shouldn't be called externally, the creator + // of client-side streams should be able to set the priority. + using QuicReliableClientStream::set_priority; + private: int ParseResponseHeaders(); diff --git a/chromium/net/tools/quic/quic_spdy_server_stream.cc b/chromium/net/tools/quic/quic_spdy_server_stream.cc index d6f3b7590ff..2ebe1dc64c6 100644 --- a/chromium/net/tools/quic/quic_spdy_server_stream.cc +++ b/chromium/net/tools/quic/quic_spdy_server_stream.cc @@ -69,10 +69,11 @@ void QuicSpdyServerStream::SendHeaders( const BalsaHeaders& response_headers) { SpdyHeaderBlock header_block = SpdyUtils::ResponseHeadersToSpdyHeaders(response_headers); - string headers = - session()->compressor()->CompressHeaders(header_block); - WriteData(headers, false); + string headers_string; + headers_string = session()->compressor()->CompressHeaders(header_block); + + WriteData(headers_string, false); } int QuicSpdyServerStream::ParseRequestHeaders() { diff --git a/chromium/net/tools/quic/quic_time_wait_list_manager_test.cc b/chromium/net/tools/quic/quic_time_wait_list_manager_test.cc new file mode 100644 index 00000000000..8f596874370 --- /dev/null +++ b/chromium/net/tools/quic/quic_time_wait_list_manager_test.cc @@ -0,0 +1,397 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/tools/quic/quic_time_wait_list_manager.h" + +#include <errno.h> + +#include "net/quic/crypto/crypto_protocol.h" +#include "net/quic/crypto/null_encrypter.h" +#include "net/quic/crypto/quic_decrypter.h" +#include "net/quic/crypto/quic_encrypter.h" +#include "net/quic/quic_data_reader.h" +#include "net/quic/quic_framer.h" +#include "net/quic/quic_protocol.h" +#include "net/quic/test_tools/quic_test_utils.h" +#include "net/tools/quic/quic_packet_writer.h" +#include "net/tools/quic/test_tools/mock_epoll_server.h" +#include "net/tools/quic/test_tools/quic_test_utils.h" +#include "testing/gmock/include/gmock/gmock.h" +#include "testing/gtest/include/gtest/gtest.h" + +using net::test::FramerVisitorCapturingPublicReset; +using testing::_; +using testing::Args; +using testing::Matcher; +using testing::MatcherInterface; +using testing::Return; +using testing::SetArgPointee; +using testing::Truly; + +namespace net { +namespace tools { +namespace test { +namespace { + +class TestTimeWaitListManager : public QuicTimeWaitListManager { + public: + TestTimeWaitListManager(QuicPacketWriter* writer, + EpollServer* epoll_server) + : QuicTimeWaitListManager(writer, epoll_server) { + } + + using QuicTimeWaitListManager::is_write_blocked; + using QuicTimeWaitListManager::time_wait_period; + using QuicTimeWaitListManager::ShouldSendPublicReset; + using QuicTimeWaitListManager::GetQuicVersionFromGuid; +}; + +class MockFakeTimeEpollServer : public FakeTimeEpollServer { + public: + MOCK_METHOD2(RegisterAlarm, void(int64 timeout_in_us, + EpollAlarmCallbackInterface* alarm)); +}; + +class QuicTimeWaitListManagerTest : public testing::Test { + protected: + QuicTimeWaitListManagerTest() + : time_wait_list_manager_(&writer_, &epoll_server_), + framer_(QuicVersionMax(), + QuicTime::Zero(), + true), + guid_(45) { + } + + void AddGuid(QuicGuid guid) { + time_wait_list_manager_.AddGuidToTimeWait(guid, QuicVersionMax()); + } + + void AddGuid(QuicGuid guid, QuicVersion version) { + time_wait_list_manager_.AddGuidToTimeWait(guid, version); + } + + bool IsGuidInTimeWait(QuicGuid guid) { + return time_wait_list_manager_.IsGuidInTimeWait(guid); + } + + void ProcessPacket(QuicGuid guid, const QuicEncryptedPacket& packet) { + time_wait_list_manager_.ProcessPacket(server_address_, + client_address_, + guid, + packet); + } + + QuicEncryptedPacket* ConstructEncryptedPacket( + QuicGuid guid, + QuicPacketSequenceNumber sequence_number) { + QuicPacketHeader header; + header.public_header.guid = guid; + header.public_header.guid_length = PACKET_8BYTE_GUID; + header.public_header.version_flag = false; + header.public_header.reset_flag = false; + header.public_header.sequence_number_length = PACKET_6BYTE_SEQUENCE_NUMBER; + header.packet_sequence_number = sequence_number; + header.entropy_flag = false; + header.entropy_hash = 0; + header.fec_flag = false; + header.is_in_fec_group = NOT_IN_FEC_GROUP; + header.fec_group = 0; + QuicStreamFrame stream_frame(1, false, 0, "data"); + QuicFrame frame(&stream_frame); + QuicFrames frames; + frames.push_back(frame); + scoped_ptr<QuicPacket> packet( + framer_.BuildUnsizedDataPacket(header, frames).packet); + EXPECT_TRUE(packet != NULL); + QuicEncryptedPacket* encrypted = framer_.EncryptPacket(ENCRYPTION_NONE, + sequence_number, + *packet); + EXPECT_TRUE(encrypted != NULL); + return encrypted; + } + + MockFakeTimeEpollServer epoll_server_; + MockPacketWriter writer_; + TestTimeWaitListManager time_wait_list_manager_; + QuicFramer framer_; + QuicGuid guid_; + IPEndPoint server_address_; + IPEndPoint client_address_; +}; + +class ValidatePublicResetPacketPredicate + : public MatcherInterface<const std::tr1::tuple<const char*, int> > { + public: + explicit ValidatePublicResetPacketPredicate(QuicGuid guid, + QuicPacketSequenceNumber number) + : guid_(guid), sequence_number_(number) { + } + + virtual bool MatchAndExplain( + const std::tr1::tuple<const char*, int> packet_buffer, + testing::MatchResultListener* /* listener */) const { + FramerVisitorCapturingPublicReset visitor; + QuicFramer framer(QuicVersionMax(), + QuicTime::Zero(), + false); + framer.set_visitor(&visitor); + QuicEncryptedPacket encrypted(std::tr1::get<0>(packet_buffer), + std::tr1::get<1>(packet_buffer)); + framer.ProcessPacket(encrypted); + QuicPublicResetPacket packet = visitor.public_reset_packet(); + return guid_ == packet.public_header.guid && + packet.public_header.reset_flag && !packet.public_header.version_flag && + sequence_number_ == packet.rejected_sequence_number; + } + + virtual void DescribeTo(::std::ostream* os) const { } + + virtual void DescribeNegationTo(::std::ostream* os) const { } + + private: + QuicGuid guid_; + QuicPacketSequenceNumber sequence_number_; +}; + +void ValidPublicResetPacketPredicate( + QuicGuid expected_guid, + QuicPacketSequenceNumber expected_sequence_number, + const std::tr1::tuple<const char*, int>& packet_buffer) { + FramerVisitorCapturingPublicReset visitor; + QuicFramer framer(QuicVersionMax(), + QuicTime::Zero(), + false); + framer.set_visitor(&visitor); + QuicEncryptedPacket encrypted(std::tr1::get<0>(packet_buffer), + std::tr1::get<1>(packet_buffer)); + framer.ProcessPacket(encrypted); + QuicPublicResetPacket packet = visitor.public_reset_packet(); + EXPECT_EQ(expected_guid, packet.public_header.guid); + EXPECT_TRUE(packet.public_header.reset_flag); + EXPECT_FALSE(packet.public_header.version_flag); + EXPECT_EQ(expected_sequence_number, packet.rejected_sequence_number); +} + + +Matcher<const std::tr1::tuple<const char*, int> > PublicResetPacketEq( + QuicGuid guid, + QuicPacketSequenceNumber sequence_number) { + return MakeMatcher(new ValidatePublicResetPacketPredicate(guid, + sequence_number)); +} + +TEST_F(QuicTimeWaitListManagerTest, CheckGuidInTimeWait) { + EXPECT_FALSE(IsGuidInTimeWait(guid_)); + AddGuid(guid_); + EXPECT_TRUE(IsGuidInTimeWait(guid_)); +} + +TEST_F(QuicTimeWaitListManagerTest, SendPublicReset) { + AddGuid(guid_); + const int kRandomSequenceNumber = 1; + scoped_ptr<QuicEncryptedPacket> packet( + ConstructEncryptedPacket(guid_, kRandomSequenceNumber)); + EXPECT_CALL(writer_, WritePacket(_, _, + server_address_.address(), + client_address_, + &time_wait_list_manager_, + _)) + .With(Args<0, 1>(PublicResetPacketEq(guid_, + kRandomSequenceNumber))) + .WillOnce(Return(packet->length())); + + ProcessPacket(guid_, *packet); +} + +TEST_F(QuicTimeWaitListManagerTest, DropInvalidPacket) { + AddGuid(guid_); + const char buffer[] = "invalid"; + QuicEncryptedPacket packet(buffer, arraysize(buffer)); + ProcessPacket(guid_, packet); + // Will get called for a valid packet since received packet count = 1 (2 ^ 0). + EXPECT_CALL(writer_, WritePacket(_, _, _, _, _, _)).Times(0); +} + +TEST_F(QuicTimeWaitListManagerTest, DropPublicResetPacket) { + AddGuid(guid_); + QuicPublicResetPacket packet; + packet.public_header.guid = guid_; + packet.public_header.version_flag = false; + packet.public_header.reset_flag = true; + packet.rejected_sequence_number = 239191; + packet.nonce_proof = 1010101; + scoped_ptr<QuicEncryptedPacket> public_reset_packet( + QuicFramer::BuildPublicResetPacket(packet)); + ProcessPacket(guid_, *public_reset_packet); + // Will get called for a data packet since received packet count = 1 (2 ^ 0). + EXPECT_CALL(writer_, WritePacket(_, _, _, _, _, _)) + .Times(0); +} + +TEST_F(QuicTimeWaitListManagerTest, SendPublicResetWithExponentialBackOff) { + AddGuid(guid_); + for (int sequence_number = 1; sequence_number < 101; ++sequence_number) { + scoped_ptr<QuicEncryptedPacket> packet( + ConstructEncryptedPacket(guid_, sequence_number)); + if ((sequence_number & (sequence_number - 1)) == 0) { + EXPECT_CALL(writer_, WritePacket(_, _, _, _, _, _)) + .WillOnce(Return(1)); + } + ProcessPacket(guid_, *packet); + // Send public reset with exponential back off. + if ((sequence_number & (sequence_number - 1)) == 0) { + EXPECT_TRUE( + time_wait_list_manager_.ShouldSendPublicReset(sequence_number)); + } else { + EXPECT_FALSE( + time_wait_list_manager_.ShouldSendPublicReset(sequence_number)); + } + } +} + +TEST_F(QuicTimeWaitListManagerTest, CleanUpOldGuids) { + const int kGuidCount = 100; + const int kOldGuidCount = 31; + + // Add guids such that their expiry time is kTimeWaitPeriod_. + epoll_server_.set_now_in_usec(0); + for (int guid = 1; guid <= kOldGuidCount; ++guid) { + AddGuid(guid); + } + + // Add remaining guids such that their add time is 2 * kTimeWaitPeriod. + const QuicTime::Delta time_wait_period = + time_wait_list_manager_.time_wait_period(); + epoll_server_.set_now_in_usec(time_wait_period.ToMicroseconds()); + for (int guid = kOldGuidCount + 1; guid <= kGuidCount; ++guid) { + AddGuid(guid); + } + + QuicTime::Delta offset = QuicTime::Delta::FromMicroseconds(39); + // Now set the current time as time_wait_period + offset usecs. + epoll_server_.set_now_in_usec(time_wait_period.Add(offset).ToMicroseconds()); + // After all the old guids are cleaned up, check the next alarm interval. + int64 next_alarm_time = epoll_server_.ApproximateNowInUsec() + + time_wait_period.Subtract(offset).ToMicroseconds(); + EXPECT_CALL(epoll_server_, RegisterAlarm(next_alarm_time, _)); + + time_wait_list_manager_.CleanUpOldGuids(); + for (int guid = 1; guid <= kGuidCount; ++guid) { + EXPECT_EQ(guid > kOldGuidCount, IsGuidInTimeWait(guid)) + << "kOldGuidCount: " << kOldGuidCount + << " guid: " << guid; + } +} + +TEST_F(QuicTimeWaitListManagerTest, SendQueuedPackets) { + QuicGuid guid = 1; + AddGuid(guid); + QuicPacketSequenceNumber sequence_number = 234; + scoped_ptr<QuicEncryptedPacket> packet( + ConstructEncryptedPacket(guid, sequence_number)); + // Let first write through. + EXPECT_CALL(writer_, WritePacket(_, _, + server_address_.address(), + client_address_, + &time_wait_list_manager_, + _)) + .With(Args<0, 1>(PublicResetPacketEq(guid, + sequence_number))) + .WillOnce(Return(packet->length())); + ProcessPacket(guid, *packet); + EXPECT_FALSE(time_wait_list_manager_.is_write_blocked()); + + // write block for the next packet. + EXPECT_CALL(writer_, WritePacket(_, _, + server_address_.address(), + client_address_, + &time_wait_list_manager_, + _)) + .With(Args<0, 1>(PublicResetPacketEq(guid, + sequence_number))) + .WillOnce(DoAll(SetArgPointee<5>(EAGAIN), Return(-1))); + ProcessPacket(guid, *packet); + // 3rd packet. No public reset should be sent; + ProcessPacket(guid, *packet); + EXPECT_TRUE(time_wait_list_manager_.is_write_blocked()); + + // write packet should not be called since already write blocked but the + // should be queued. + QuicGuid other_guid = 2; + AddGuid(other_guid); + QuicPacketSequenceNumber other_sequence_number = 23423; + scoped_ptr<QuicEncryptedPacket> other_packet( + ConstructEncryptedPacket(other_guid, other_sequence_number)); + EXPECT_CALL(writer_, WritePacket(_, _, _, _, _, _)) + .Times(0); + ProcessPacket(other_guid, *other_packet); + + // Now expect all the write blocked public reset packets to be sent again. + EXPECT_CALL(writer_, WritePacket(_, _, + server_address_.address(), + client_address_, + &time_wait_list_manager_, + _)) + .With(Args<0, 1>(PublicResetPacketEq(guid, + sequence_number))) + .WillOnce(Return(packet->length())); + EXPECT_CALL(writer_, WritePacket(_, _, + server_address_.address(), + client_address_, + &time_wait_list_manager_, + _)) + .With(Args<0, 1>(PublicResetPacketEq(other_guid, + other_sequence_number))) + .WillOnce(Return(other_packet->length())); + time_wait_list_manager_.OnCanWrite(); + EXPECT_FALSE(time_wait_list_manager_.is_write_blocked()); +} + +TEST_F(QuicTimeWaitListManagerTest, MakeSureFramerUsesCorrectVersion) { + const int kRandomSequenceNumber = 1; + scoped_ptr<QuicEncryptedPacket> packet; + + AddGuid(guid_, QuicVersionMin()); + framer_.set_version(QuicVersionMin()); + packet.reset(ConstructEncryptedPacket(guid_, kRandomSequenceNumber)); + + // Reset packet should be written, using the minimum quic version. + EXPECT_CALL(writer_, WritePacket(_, _, _, _, _, _)).Times(1); + ProcessPacket(guid_, *packet); + EXPECT_EQ(time_wait_list_manager_.version(), QuicVersionMin()); + + // New guid + ++guid_; + + AddGuid(guid_, QuicVersionMax()); + framer_.set_version(QuicVersionMax()); + packet.reset(ConstructEncryptedPacket(guid_, kRandomSequenceNumber)); + + // Reset packet should be written, using the maximum quic version. + EXPECT_CALL(writer_, WritePacket(_, _, _, _, _, _)).Times(1); + ProcessPacket(guid_, *packet); + EXPECT_EQ(time_wait_list_manager_.version(), QuicVersionMax()); +} + +TEST_F(QuicTimeWaitListManagerTest, GetQuicVersionFromMap) { + const int kGuid1 = 123; + const int kGuid2 = 456; + const int kGuid3 = 789; + + AddGuid(kGuid1, QuicVersionMin()); + AddGuid(kGuid2, QuicVersionMax()); + AddGuid(kGuid3, QuicVersionMax()); + + EXPECT_EQ(QuicVersionMin(), + time_wait_list_manager_.GetQuicVersionFromGuid(kGuid1)); + EXPECT_EQ(QuicVersionMax(), + time_wait_list_manager_.GetQuicVersionFromGuid(kGuid2)); + EXPECT_EQ(QuicVersionMax(), + time_wait_list_manager_.GetQuicVersionFromGuid(kGuid3)); +} + +} // namespace +} // namespace test +} // namespace tools +} // namespace net diff --git a/chromium/net/tools/quic/test_tools/http_message_test_utils.cc b/chromium/net/tools/quic/test_tools/http_message_test_utils.cc index 7d6df7a7649..70eb59290fa 100644 --- a/chromium/net/tools/quic/test_tools/http_message_test_utils.cc +++ b/chromium/net/tools/quic/test_tools/http_message_test_utils.cc @@ -54,7 +54,6 @@ const char* kMethodString[] = { // - Neither Transfer-Encoding nor Content-Length is present and message // is tagged as complete. bool IsCompleteMessage(const HTTPMessage& message) { - return true; const BalsaHeaders* headers = message.headers(); StringPiece content_length = headers->GetHeader(kContentLength); if (!content_length.empty()) { diff --git a/chromium/net/tools/quic/test_tools/quic_client_peer.cc b/chromium/net/tools/quic/test_tools/quic_client_peer.cc index 858359474c8..25fdb7eedc5 100644 --- a/chromium/net/tools/quic/test_tools/quic_client_peer.cc +++ b/chromium/net/tools/quic/test_tools/quic_client_peer.cc @@ -11,13 +11,6 @@ namespace tools { namespace test { // static -void QuicClientPeer::Reinitialize(QuicClient* client) { - client->initialized_ = false; - client->epoll_server_.UnregisterFD(client->fd_); - client->Initialize(); -} - -// static int QuicClientPeer::GetFd(QuicClient* client) { return client->fd_; } diff --git a/chromium/net/tools/quic/test_tools/quic_client_peer.h b/chromium/net/tools/quic/test_tools/quic_client_peer.h index 8eaa17e675d..016120aa8bf 100644 --- a/chromium/net/tools/quic/test_tools/quic_client_peer.h +++ b/chromium/net/tools/quic/test_tools/quic_client_peer.h @@ -14,7 +14,6 @@ namespace test { class QuicClientPeer { public: - static void Reinitialize(QuicClient* client); static int GetFd(QuicClient* client); }; diff --git a/chromium/net/tools/quic/test_tools/quic_test_client.cc b/chromium/net/tools/quic/test_tools/quic_test_client.cc index 859d7a56044..272786ebf45 100644 --- a/chromium/net/tools/quic/test_tools/quic_test_client.cc +++ b/chromium/net/tools/quic/test_tools/quic_test_client.cc @@ -11,6 +11,7 @@ #include "net/quic/crypto/proof_verifier.h" #include "net/tools/flip_server/balsa_headers.h" #include "net/tools/quic/quic_epoll_connection_helper.h" +#include "net/tools/quic/quic_spdy_client_stream.h" #include "net/tools/quic/test_tools/http_message_test_utils.h" #include "url/gurl.h" @@ -26,7 +27,6 @@ class RecordingProofVerifier : public net::ProofVerifier { public: // ProofVerifier interface. virtual net::ProofVerifier::Status VerifyProof( - net::QuicVersion version, const string& hostname, const string& server_config, const vector<string>& certs, @@ -158,6 +158,7 @@ void QuicTestClient::Initialize(IPEndPoint address, server_address_ = address; stream_ = NULL; stream_error_ = QUIC_STREAM_NO_ERROR; + priority_ = 3; bytes_read_ = 0; bytes_written_= 0; never_connected_ = true; @@ -244,10 +245,13 @@ QuicReliableClientStream* QuicTestClient::GetOrCreateStream() { } if (!stream_) { stream_ = client_->CreateReliableClientStream(); - if (stream_ != NULL) { - stream_->set_visitor(this); + if (stream_ == NULL) { + return NULL; } + stream_->set_visitor(this); + reinterpret_cast<QuicSpdyClientStream*>(stream_)->set_priority(priority_); } + return stream_; } diff --git a/chromium/net/tools/quic/test_tools/quic_test_client.h b/chromium/net/tools/quic/test_tools/quic_test_client.h index 74bfc24646a..3cd71d59f81 100644 --- a/chromium/net/tools/quic/test_tools/quic_test_client.h +++ b/chromium/net/tools/quic/test_tools/quic_test_client.h @@ -107,6 +107,8 @@ class QuicTestClient : public ReliableQuicStream::Visitor { void set_auto_reconnect(bool reconnect) { auto_reconnect_ = reconnect; } + void set_priority(QuicPriority priority) { priority_ = priority; } + private: void Initialize(IPEndPoint address, const string& hostname, bool secure); @@ -118,6 +120,8 @@ class QuicTestClient : public ReliableQuicStream::Visitor { QuicRstStreamErrorCode stream_error_; BalsaHeaders headers_; + QuicPriority priority_; + string response_; uint64 bytes_read_; uint64 bytes_written_; diff --git a/chromium/net/tools/quic/test_tools/quic_test_utils.cc b/chromium/net/tools/quic/test_tools/quic_test_utils.cc index 95f1fb215ea..fa3627ba74f 100644 --- a/chromium/net/tools/quic/test_tools/quic_test_utils.cc +++ b/chromium/net/tools/quic/test_tools/quic_test_utils.cc @@ -50,6 +50,13 @@ void MockConnection::AdvanceTime(QuicTime::Delta delta) { static_cast<MockHelper*>(helper())->AdvanceTime(delta); } + +MockQuicSessionOwner::MockQuicSessionOwner() { +} + +MockQuicSessionOwner::~MockQuicSessionOwner() { +} + bool TestDecompressorVisitor::OnDecompressedData(StringPiece data) { data.AppendToString(&data_); return true; @@ -76,6 +83,18 @@ QuicCryptoStream* TestSession::GetCryptoStream() { return crypto_stream_; } +MockAckNotifierDelegate::MockAckNotifierDelegate() { +} + +MockAckNotifierDelegate::~MockAckNotifierDelegate() { +} + +MockPacketWriter::MockPacketWriter() { +} + +MockPacketWriter::~MockPacketWriter() { +} + } // namespace test } // namespace tools } // namespace net diff --git a/chromium/net/tools/quic/test_tools/quic_test_utils.h b/chromium/net/tools/quic/test_tools/quic_test_utils.h index 31ea1815e06..dffa2558af3 100644 --- a/chromium/net/tools/quic/test_tools/quic_test_utils.h +++ b/chromium/net/tools/quic/test_tools/quic_test_utils.h @@ -12,6 +12,8 @@ #include "net/quic/quic_session.h" #include "net/quic/quic_spdy_decompressor.h" #include "net/spdy/spdy_framer.h" +#include "net/tools/quic/quic_packet_writer.h" +#include "net/tools/quic/quic_server_session.h" #include "testing/gmock/include/gmock/gmock.h" namespace net { @@ -22,8 +24,6 @@ class IPEndPoint; namespace tools { namespace test { -std::string SerializeUncompressedHeaders(const SpdyHeaderBlock& headers); - class MockConnection : public QuicConnection { public: // Uses a QuicConnectionHelper created with fd and eps. @@ -71,6 +71,13 @@ class MockConnection : public QuicConnection { DISALLOW_COPY_AND_ASSIGN(MockConnection); }; +class MockQuicSessionOwner : public QuicSessionOwner { + public: + MockQuicSessionOwner(); + ~MockQuicSessionOwner(); + MOCK_METHOD2(OnConnectionClose, void(QuicGuid guid, QuicErrorCode error)); +}; + class TestDecompressorVisitor : public QuicSpdyDecompressor::Visitor { public: virtual ~TestDecompressorVisitor() {} @@ -105,6 +112,27 @@ class TestSession : public QuicSession { DISALLOW_COPY_AND_ASSIGN(TestSession); }; +class MockAckNotifierDelegate : public QuicAckNotifier::DelegateInterface { + public: + MockAckNotifierDelegate(); + virtual ~MockAckNotifierDelegate(); + + MOCK_METHOD0(OnAckNotification, void()); +}; + +class MockPacketWriter : public QuicPacketWriter { + public: + MockPacketWriter(); + virtual ~MockPacketWriter(); + + MOCK_METHOD6(WritePacket, int(const char* buffer, + size_t buf_len, + const IPAddressNumber& self_address, + const IPEndPoint& peer_address, + QuicBlockedWriterInterface* blocked_writer, + int* error)); +}; + } // namespace test } // namespace tools } // namespace net diff --git a/chromium/net/tools/testserver/testserver.py b/chromium/net/tools/testserver/testserver.py index 77a31426300..e0317e88685 100755 --- a/chromium/net/tools/testserver/testserver.py +++ b/chromium/net/tools/testserver/testserver.py @@ -21,6 +21,7 @@ import hashlib import logging import minica import os +import json import random import re import select @@ -282,7 +283,8 @@ class TestPageHandler(testserver_base.BasePageHandler): post_handlers = [ self.EchoTitleHandler, self.EchoHandler, - self.PostOnlyFileHandler] + get_handlers + self.PostOnlyFileHandler, + self.EchoMultipartPostHandler] + get_handlers put_handlers = [ self.EchoTitleHandler, self.EchoHandler] + get_handlers @@ -298,6 +300,7 @@ class TestPageHandler(testserver_base.BasePageHandler): 'jpg' : 'image/jpeg', 'json': 'application/json', 'pdf' : 'application/pdf', + 'txt' : 'text/plain', 'wav' : 'audio/wav', 'xml' : 'text/xml' } @@ -662,6 +665,37 @@ class TestPageHandler(testserver_base.BasePageHandler): self.wfile.write('</body></html>') return True + def EchoMultipartPostHandler(self): + """This handler echoes received multipart post data as json format.""" + + if not (self._ShouldHandleRequest("/echomultipartpost") or + self._ShouldHandleRequest("/searchbyimage")): + return False + + content_type, parameters = cgi.parse_header( + self.headers.getheader('content-type')) + if content_type == 'multipart/form-data': + post_multipart = cgi.parse_multipart(self.rfile, parameters) + elif content_type == 'application/x-www-form-urlencoded': + raise Exception('POST by application/x-www-form-urlencoded is ' + 'not implemented.') + else: + post_multipart = {} + + # Since the data can be binary, we encode them by base64. + post_multipart_base64_encoded = {} + for field, values in post_multipart.items(): + post_multipart_base64_encoded[field] = [base64.b64encode(value) + for value in values] + + result = {'POST_multipart' : post_multipart_base64_encoded} + + self.send_response(200) + self.send_header("Content-type", "text/plain") + self.end_headers() + self.wfile.write(json.dumps(result, indent=2, sort_keys=False)) + return True + def DownloadHandler(self): """This handler sends a downloadable file with or without reporting the size (6K).""" @@ -1314,7 +1348,7 @@ class TestPageHandler(testserver_base.BasePageHandler): if query_char < 0 or len(self.path) <= query_char + 1: self.sendRedirectHelp(test_name) return True - dest = self.path[query_char + 1:] + dest = urllib.unquote(self.path[query_char + 1:]) self.send_response(301) # moved permanently self.send_header('Location', dest) @@ -1338,7 +1372,7 @@ class TestPageHandler(testserver_base.BasePageHandler): if query_char < 0 or len(self.path) <= query_char + 1: self.sendRedirectHelp(test_name) return True - dest = self.path[query_char + 1:] + dest = urllib.unquote(self.path[query_char + 1:]) self.send_response(200) self.send_header('Content-Type', 'text/html') diff --git a/chromium/net/tools/tld_cleanup/tld_cleanup.cc b/chromium/net/tools/tld_cleanup/tld_cleanup.cc index 9d5337c6d7c..a4b127bdf2c 100644 --- a/chromium/net/tools/tld_cleanup/tld_cleanup.cc +++ b/chromium/net/tools/tld_cleanup/tld_cleanup.cc @@ -64,7 +64,7 @@ int main(int argc, const char* argv[]) { settings.delete_old = logging::DELETE_OLD_LOG_FILE; logging::InitLogging(settings); - icu_util::Initialize(); + base::i18n::InitializeICU(); base::FilePath input_file; PathService::Get(base::DIR_SOURCE_ROOT, &input_file); diff --git a/chromium/net/tools/tld_cleanup/tld_cleanup_util.cc b/chromium/net/tools/tld_cleanup/tld_cleanup_util.cc index dfa26206f98..8e04b55bd1c 100644 --- a/chromium/net/tools/tld_cleanup/tld_cleanup_util.cc +++ b/chromium/net/tools/tld_cleanup/tld_cleanup_util.cc @@ -233,7 +233,7 @@ NormalizeResult NormalizeFile(const base::FilePath& in_filename, const base::FilePath& out_filename) { RuleMap rules; std::string data; - if (!file_util::ReadFileToString(in_filename, &data)) { + if (!base::ReadFileToString(in_filename, &data)) { LOG(ERROR) << "Unable to read file"; // We return success since we've already reported the error. return kSuccess; diff --git a/chromium/net/udp/udp_socket_libevent.cc b/chromium/net/udp/udp_socket_libevent.cc index 90c7da65041..5ed52f57090 100644 --- a/chromium/net/udp/udp_socket_libevent.cc +++ b/chromium/net/udp/udp_socket_libevent.cc @@ -21,6 +21,7 @@ #include "net/base/net_errors.h" #include "net/base/net_log.h" #include "net/base/net_util.h" +#include "net/socket/socket_descriptor.h" #include "net/udp/udp_net_log_parameters.h" namespace { @@ -381,7 +382,7 @@ void UDPSocketLibevent::LogRead(int result, int UDPSocketLibevent::CreateSocket(const IPEndPoint& address) { addr_family_ = address.GetSockAddrFamily(); - socket_ = socket(addr_family_, SOCK_DGRAM, 0); + socket_ = CreatePlatformSocket(addr_family_, SOCK_DGRAM, 0); if (socket_ == kInvalidSocket) return MapSystemError(errno); if (SetNonBlocking(socket_)) { diff --git a/chromium/net/udp/udp_socket_libevent.h b/chromium/net/udp/udp_socket_libevent.h index 8f68a9b57a9..6c8bf61e0dd 100644 --- a/chromium/net/udp/udp_socket_libevent.h +++ b/chromium/net/udp/udp_socket_libevent.h @@ -15,6 +15,7 @@ #include "net/base/net_export.h" #include "net/base/net_log.h" #include "net/base/rand_callback.h" +#include "net/socket/socket_descriptor.h" #include "net/udp/datagram_socket.h" namespace net { @@ -152,8 +153,6 @@ class NET_EXPORT UDPSocketLibevent : public base::NonThreadSafe { int SetMulticastLoopbackMode(bool loopback); private: - static const int kInvalidSocket = -1; - enum SocketOptions { SOCKET_OPTION_REUSE_ADDRESS = 1 << 0, SOCKET_OPTION_BROADCAST = 1 << 1, diff --git a/chromium/net/udp/udp_socket_win.cc b/chromium/net/udp/udp_socket_win.cc index 1f0c337fd4a..f9ce6799c4b 100644 --- a/chromium/net/udp/udp_socket_win.cc +++ b/chromium/net/udp/udp_socket_win.cc @@ -20,6 +20,7 @@ #include "net/base/net_util.h" #include "net/base/winsock_init.h" #include "net/base/winsock_util.h" +#include "net/socket/socket_descriptor.h" #include "net/udp/udp_net_log_parameters.h" namespace { diff --git a/chromium/net/url_request/data_protocol_handler.cc b/chromium/net/url_request/data_protocol_handler.cc index 3222f725920..c6d9716e458 100644 --- a/chromium/net/url_request/data_protocol_handler.cc +++ b/chromium/net/url_request/data_protocol_handler.cc @@ -16,4 +16,8 @@ URLRequestJob* DataProtocolHandler::MaybeCreateJob( return new URLRequestDataJob(request, network_delegate); } +bool DataProtocolHandler::IsSafeRedirectTarget(const GURL& location) const { + return false; +} + } // namespace net diff --git a/chromium/net/url_request/data_protocol_handler.h b/chromium/net/url_request/data_protocol_handler.h index abb5abe2990..b7f7fefda0b 100644 --- a/chromium/net/url_request/data_protocol_handler.h +++ b/chromium/net/url_request/data_protocol_handler.h @@ -20,6 +20,7 @@ class NET_EXPORT DataProtocolHandler DataProtocolHandler(); virtual URLRequestJob* MaybeCreateJob( URLRequest* request, NetworkDelegate* network_delegate) const OVERRIDE; + virtual bool IsSafeRedirectTarget(const GURL& location) const OVERRIDE; private: DISALLOW_COPY_AND_ASSIGN(DataProtocolHandler); diff --git a/chromium/net/url_request/file_protocol_handler.cc b/chromium/net/url_request/file_protocol_handler.cc index dc5b16f1bbe..ef8096f8798 100644 --- a/chromium/net/url_request/file_protocol_handler.cc +++ b/chromium/net/url_request/file_protocol_handler.cc @@ -5,6 +5,7 @@ #include "net/url_request/file_protocol_handler.h" #include "base/logging.h" +#include "base/task_runner.h" #include "net/base/net_errors.h" #include "net/base/net_util.h" #include "net/url_request/url_request.h" @@ -14,7 +15,11 @@ namespace net { -FileProtocolHandler::FileProtocolHandler() { } +FileProtocolHandler::FileProtocolHandler( + const scoped_refptr<base::TaskRunner>& file_task_runner) + : file_task_runner_(file_task_runner) {} + +FileProtocolHandler::~FileProtocolHandler() {} URLRequestJob* FileProtocolHandler::MaybeCreateJob( URLRequest* request, NetworkDelegate* network_delegate) const { @@ -41,7 +46,8 @@ URLRequestJob* FileProtocolHandler::MaybeCreateJob( // Use a regular file request job for all non-directories (including invalid // file names). - return new URLRequestFileJob(request, network_delegate, file_path); + return new URLRequestFileJob(request, network_delegate, file_path, + file_task_runner_); } bool FileProtocolHandler::IsSafeRedirectTarget(const GURL& location) const { diff --git a/chromium/net/url_request/file_protocol_handler.h b/chromium/net/url_request/file_protocol_handler.h index 8087a6ee5a9..78956a93471 100644 --- a/chromium/net/url_request/file_protocol_handler.h +++ b/chromium/net/url_request/file_protocol_handler.h @@ -7,10 +7,15 @@ #include "base/basictypes.h" #include "base/compiler_specific.h" +#include "base/memory/ref_counted.h" #include "net/url_request/url_request_job_factory.h" class GURL; +namespace base { +class TaskRunner; +} + namespace net { class NetworkDelegate; @@ -21,12 +26,15 @@ class URLRequestJob; class NET_EXPORT FileProtocolHandler : public URLRequestJobFactory::ProtocolHandler { public: - FileProtocolHandler(); + explicit FileProtocolHandler( + const scoped_refptr<base::TaskRunner>& file_task_runner); + virtual ~FileProtocolHandler(); virtual URLRequestJob* MaybeCreateJob( URLRequest* request, NetworkDelegate* network_delegate) const OVERRIDE; virtual bool IsSafeRedirectTarget(const GURL& location) const OVERRIDE; private: + const scoped_refptr<base::TaskRunner> file_task_runner_; DISALLOW_COPY_AND_ASSIGN(FileProtocolHandler); }; diff --git a/chromium/net/url_request/test_url_fetcher_factory.cc b/chromium/net/url_request/test_url_fetcher_factory.cc index e0394c44d18..30380354b2d 100644 --- a/chromium/net/url_request/test_url_fetcher_factory.cc +++ b/chromium/net/url_request/test_url_fetcher_factory.cc @@ -357,11 +357,18 @@ URLFetcher* FakeURLFetcherFactory::CreateURLFetcher( return fake_fetcher.release(); } +void FakeURLFetcherFactory::SetFakeResponseForURL( + const GURL& url, + const std::string& response_data, + bool success) { + // Overwrite existing URL if it already exists. + fake_responses_[url] = std::make_pair(response_data, success); +} + void FakeURLFetcherFactory::SetFakeResponse(const std::string& url, const std::string& response_data, bool success) { - // Overwrite existing URL if it already exists. - fake_responses_[GURL(url)] = std::make_pair(response_data, success); + SetFakeResponseForURL(GURL(url), response_data, success); } void FakeURLFetcherFactory::ClearFakeResponses() { diff --git a/chromium/net/url_request/test_url_fetcher_factory.h b/chromium/net/url_request/test_url_fetcher_factory.h index 89f74c2f590..35b4607f995 100644 --- a/chromium/net/url_request/test_url_fetcher_factory.h +++ b/chromium/net/url_request/test_url_fetcher_factory.h @@ -377,6 +377,12 @@ class FakeURLFetcherFactory : public URLFetcherFactory, // Sets the fake response for a given URL. If success is true we will serve // an HTTP/200 and an HTTP/500 otherwise. The |response_data| may be empty. + void SetFakeResponseForURL(const GURL& url, + const std::string& response_data, + bool success); + + // Convenience helper that calls SetFakeResponseForURL with GURL(url). + // TODO(mnissler): Convert callers to SetFakeResponseForURL. void SetFakeResponse(const std::string& url, const std::string& response_data, bool success); diff --git a/chromium/net/url_request/url_fetcher_core.cc b/chromium/net/url_request/url_fetcher_core.cc index 8f0e28dcd93..f8d3773b7d0 100644 --- a/chromium/net/url_request/url_fetcher_core.cc +++ b/chromium/net/url_request/url_fetcher_core.cc @@ -566,8 +566,6 @@ void URLFetcherCore::StartURLRequest() { request_->set_method( request_type_ == URLFetcher::POST ? "POST" : request_type_ == URLFetcher::PUT ? "PUT" : "PATCH"); - extra_request_headers_.SetHeader(HttpRequestHeaders::kContentType, - upload_content_type_); if (!upload_content_type_.empty()) { extra_request_headers_.SetHeader(HttpRequestHeaders::kContentType, upload_content_type_); diff --git a/chromium/net/url_request/url_fetcher_impl_unittest.cc b/chromium/net/url_request/url_fetcher_impl_unittest.cc index 7cc1674b181..62f627258db 100644 --- a/chromium/net/url_request/url_fetcher_impl_unittest.cc +++ b/chromium/net/url_request/url_fetcher_impl_unittest.cc @@ -555,7 +555,7 @@ void URLFetcherPostFileTest::CreateFetcher(const GURL& url) { void URLFetcherPostFileTest::OnURLFetchComplete(const URLFetcher* source) { std::string expected; - ASSERT_TRUE(file_util::ReadFileToString(path_, &expected)); + ASSERT_TRUE(base::ReadFileToString(path_, &expected)); ASSERT_LE(range_offset_, expected.size()); uint64 expected_size = std::min(range_length_, expected.size() - range_offset_); diff --git a/chromium/net/url_request/url_fetcher_response_writer.cc b/chromium/net/url_request/url_fetcher_response_writer.cc index cb30dad8089..a7d5b5a8a24 100644 --- a/chromium/net/url_request/url_fetcher_response_writer.cc +++ b/chromium/net/url_request/url_fetcher_response_writer.cc @@ -102,8 +102,19 @@ int URLFetcherFileWriter::Write(IOBuffer* buffer, } int URLFetcherFileWriter::Finish(const CompletionCallback& callback) { + int result = file_stream_->Close(base::Bind( + &URLFetcherFileWriter::CloseComplete, + weak_factory_.GetWeakPtr(), callback)); + if (result != ERR_IO_PENDING) + file_stream_.reset(); + return result; +} + +void URLFetcherFileWriter::CloseComplete(const CompletionCallback& callback, + int result) { + // Destroy |file_stream_| whether or not the close succeeded. file_stream_.reset(); - return OK; + callback.Run(result); } void URLFetcherFileWriter::DidWrite(const CompletionCallback& callback, diff --git a/chromium/net/url_request/url_fetcher_response_writer.h b/chromium/net/url_request/url_fetcher_response_writer.h index 3b7fe9eb546..2f15831873d 100644 --- a/chromium/net/url_request/url_fetcher_response_writer.h +++ b/chromium/net/url_request/url_fetcher_response_writer.h @@ -112,6 +112,9 @@ class URLFetcherFileWriter : public URLFetcherResponseWriter { void DidOpenFile(const CompletionCallback& callback, int result); + // Callback which gets the result of closing a file. + void CloseComplete(const CompletionCallback& callback, int result); + // The last error encountered on a file operation. OK if no error occurred. int error_code_; diff --git a/chromium/net/url_request/url_request.cc b/chromium/net/url_request/url_request.cc index ce361c69838..bf031c8bd72 100644 --- a/chromium/net/url_request/url_request.cc +++ b/chromium/net/url_request/url_request.cc @@ -453,7 +453,7 @@ void URLRequest::GetCharset(string* charset) { job_->GetCharset(charset); } -int URLRequest::GetResponseCode() { +int URLRequest::GetResponseCode() const { DCHECK(job_.get()); return job_->GetResponseCode(); } diff --git a/chromium/net/url_request/url_request.h b/chromium/net/url_request/url_request.h index 03978d6b1d2..a01656c97bd 100644 --- a/chromium/net/url_request/url_request.h +++ b/chromium/net/url_request/url_request.h @@ -528,7 +528,7 @@ class NET_EXPORT URLRequest : NON_EXPORTED_BASE(public base::NonThreadSafe), // Returns the HTTP response code (e.g., 200, 404, and so on). This method // may only be called once the delegate's OnResponseStarted method has been // called. For non-HTTP requests, this method returns -1. - int GetResponseCode(); + int GetResponseCode() const; // Get the HTTP response info in its entirety. const HttpResponseInfo& response_info() const { return response_info_; } diff --git a/chromium/net/url_request/url_request_context_builder.cc b/chromium/net/url_request/url_request_context_builder.cc index 540dfc1082b..dce8214dc85 100644 --- a/chromium/net/url_request/url_request_context_builder.cc +++ b/chromium/net/url_request/url_request_context_builder.cc @@ -156,6 +156,11 @@ class BasicURLRequestContext : public URLRequestContext { return file_thread_.message_loop(); } + scoped_refptr<base::MessageLoopProxy> file_message_loop_proxy() { + DCHECK(file_thread_.IsRunning()); + return file_thread_.message_loop_proxy(); + } + protected: virtual ~BasicURLRequestContext() {} @@ -190,7 +195,9 @@ URLRequestContextBuilder::URLRequestContextBuilder() #if !defined(DISABLE_FTP_SUPPORT) ftp_enabled_(false), #endif - http_cache_enabled_(true) {} + http_cache_enabled_(true) { +} + URLRequestContextBuilder::~URLRequestContextBuilder() {} #if defined(OS_LINUX) || defined(OS_ANDROID) @@ -301,7 +308,8 @@ URLRequestContext* URLRequestContextBuilder::Build() { if (data_enabled_) job_factory->SetProtocolHandler("data", new DataProtocolHandler); if (file_enabled_) - job_factory->SetProtocolHandler("file", new FileProtocolHandler); + job_factory->SetProtocolHandler( + "file", new FileProtocolHandler(context->file_message_loop_proxy())); #if !defined(DISABLE_FTP_SUPPORT) if (ftp_enabled_) { ftp_transaction_factory_.reset( diff --git a/chromium/net/url_request/url_request_file_job.cc b/chromium/net/url_request/url_request_file_job.cc index 437f962393e..053a22ed233 100644 --- a/chromium/net/url_request/url_request_file_job.cc +++ b/chromium/net/url_request/url_request_file_job.cc @@ -26,8 +26,8 @@ #include "base/platform_file.h" #include "base/strings/string_util.h" #include "base/synchronization/lock.h" +#include "base/task_runner.h" #include "base/threading/thread_restrictions.h" -#include "base/threading/worker_pool.h" #include "build/build_config.h" #include "net/base/file_stream.h" #include "net/base/io_buffer.h" @@ -53,26 +53,27 @@ URLRequestFileJob::FileMetaInfo::FileMetaInfo() is_directory(false) { } -URLRequestFileJob::URLRequestFileJob(URLRequest* request, - NetworkDelegate* network_delegate, - const base::FilePath& file_path) +URLRequestFileJob::URLRequestFileJob( + URLRequest* request, + NetworkDelegate* network_delegate, + const base::FilePath& file_path, + const scoped_refptr<base::TaskRunner>& file_task_runner) : URLRequestJob(request, network_delegate), file_path_(file_path), - stream_(new FileStream(NULL)), + stream_(new FileStream(NULL, file_task_runner)), + file_task_runner_(file_task_runner), remaining_bytes_(0), - weak_ptr_factory_(this) { -} + weak_ptr_factory_(this) {} void URLRequestFileJob::Start() { FileMetaInfo* meta_info = new FileMetaInfo(); - base::WorkerPool::PostTaskAndReply( + file_task_runner_->PostTaskAndReply( FROM_HERE, base::Bind(&URLRequestFileJob::FetchMetaInfo, file_path_, base::Unretained(meta_info)), base::Bind(&URLRequestFileJob::DidFetchMetaInfo, weak_ptr_factory_.GetWeakPtr(), - base::Owned(meta_info)), - true); + base::Owned(meta_info))); } void URLRequestFileJob::Kill() { diff --git a/chromium/net/url_request/url_request_file_job.h b/chromium/net/url_request/url_request_file_job.h index 6fc7fb91980..7cd7df0ba3b 100644 --- a/chromium/net/url_request/url_request_file_job.h +++ b/chromium/net/url_request/url_request_file_job.h @@ -9,6 +9,7 @@ #include <vector> #include "base/files/file_path.h" +#include "base/memory/ref_counted.h" #include "base/memory/weak_ptr.h" #include "net/base/net_export.h" #include "net/http/http_byte_range.h" @@ -17,6 +18,7 @@ namespace base{ struct PlatformFileInfo; +class TaskRunner; } namespace file_util { struct FileInfo; @@ -31,7 +33,8 @@ class NET_EXPORT URLRequestFileJob : public URLRequestJob { public: URLRequestFileJob(URLRequest* request, NetworkDelegate* network_delegate, - const base::FilePath& file_path); + const base::FilePath& file_path, + const scoped_refptr<base::TaskRunner>& file_task_runner); // URLRequestJob: virtual void Start() OVERRIDE; @@ -91,6 +94,7 @@ class NET_EXPORT URLRequestFileJob : public URLRequestJob { scoped_ptr<FileStream> stream_; FileMetaInfo meta_info_; + const scoped_refptr<base::TaskRunner> file_task_runner_; HttpByteRange byte_range_; int64 remaining_bytes_; diff --git a/chromium/net/url_request/url_request_job.cc b/chromium/net/url_request/url_request_job.cc index 669c845de98..bf4aafcb8c6 100644 --- a/chromium/net/url_request/url_request_job.cc +++ b/chromium/net/url_request/url_request_job.cc @@ -319,6 +319,10 @@ void URLRequestJob::NotifyHeadersComplete() { new_location = new_location.ReplaceComponents(replacements); } + // Redirect response bodies are not read. Notify the transaction + // so it does not treat being stopped as an error. + DoneReading(); + bool defer_redirect = false; request_->NotifyReceivedRedirect(new_location, &defer_redirect); diff --git a/chromium/net/url_request/url_request_job_unittest.cc b/chromium/net/url_request/url_request_job_unittest.cc index 354915fe12b..5f63b0927de 100644 --- a/chromium/net/url_request/url_request_job_unittest.cc +++ b/chromium/net/url_request/url_request_job_unittest.cc @@ -37,6 +37,24 @@ const MockTransaction kGZip_Transaction = { net::OK }; +const MockTransaction kRedirect_Transaction = { + "http://www.google.com/redirect", + "GET", + base::Time(), + "", + net::LOAD_NORMAL, + "HTTP/1.1 302 Found", + "Cache-Control: max-age=10000\n" + "Location: http://www.google.com/destination\n" + "Content-Length: 5\n", + base::Time(), + "hello", + TEST_MODE_NORMAL, + NULL, + 0, + net::OK +}; + } // namespace TEST(URLRequestJob, TransactionNotifiedWhenDone) { @@ -78,3 +96,22 @@ TEST(URLRequestJob, SyncTransactionNotifiedWhenDone) { RemoveMockTransaction(&transaction); } + +TEST(URLRequestJob, RedirectTransactionNotifiedWhenDone) { + MockNetworkLayer network_layer; + net::TestURLRequestContext context; + context.set_http_transaction_factory(&network_layer); + + net::TestDelegate d; + net::TestURLRequest req(GURL(kRedirect_Transaction.url), &d, &context, NULL); + AddMockTransaction(&kRedirect_Transaction); + + req.set_method("GET"); + req.Start(); + + base::MessageLoop::current()->Run(); + + EXPECT_TRUE(network_layer.done_reading_called()); + + RemoveMockTransaction(&kRedirect_Transaction); +} diff --git a/chromium/net/url_request/url_request_unittest.cc b/chromium/net/url_request/url_request_unittest.cc index 64f5ac1d477..d18de134416 100644 --- a/chromium/net/url_request/url_request_unittest.cc +++ b/chromium/net/url_request/url_request_unittest.cc @@ -19,7 +19,9 @@ #include "base/format_macros.h" #include "base/memory/weak_ptr.h" #include "base/message_loop/message_loop.h" +#include "base/message_loop/message_loop_proxy.h" #include "base/path_service.h" +#include "base/run_loop.h" #include "base/strings/string_number_conversions.h" #include "base/strings/string_piece.h" #include "base/strings/string_split.h" @@ -591,11 +593,15 @@ class URLRequestTest : public PlatformTest { default_context_.set_network_delegate(&default_network_delegate_); default_context_.set_net_log(&net_log_); job_factory_.SetProtocolHandler("data", new DataProtocolHandler); - job_factory_.SetProtocolHandler("file", new FileProtocolHandler); + job_factory_.SetProtocolHandler( + "file", new FileProtocolHandler(base::MessageLoopProxy::current())); default_context_.set_job_factory(&job_factory_); default_context_.Init(); } - virtual ~URLRequestTest() {} + virtual ~URLRequestTest() { + // URLRequestJobs may post clean-up tasks on destruction. + base::RunLoop().RunUntilIdle(); + } // Adds the TestJobInterceptor to the default context. TestJobInterceptor* AddTestInterceptor() { @@ -620,7 +626,7 @@ TEST_F(URLRequestTest, AboutBlankTest) { r.Start(); EXPECT_TRUE(r.is_pending()); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_TRUE(!r.is_pending()); EXPECT_FALSE(d.received_data_before_response()); @@ -663,7 +669,7 @@ TEST_F(URLRequestTest, DataURLImageTest) { r.Start(); EXPECT_TRUE(r.is_pending()); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_TRUE(!r.is_pending()); EXPECT_FALSE(d.received_data_before_response()); @@ -688,7 +694,7 @@ TEST_F(URLRequestTest, FileTest) { r.Start(); EXPECT_TRUE(r.is_pending()); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); int64 file_size = -1; EXPECT_TRUE(file_util::GetFileSize(app_path, &file_size)); @@ -720,7 +726,7 @@ TEST_F(URLRequestTest, FileTestCancel) { } // Async cancellation should be safe even when URLRequest has been already // destroyed. - base::MessageLoop::current()->RunUntilIdle(); + base::RunLoop().RunUntilIdle(); } TEST_F(URLRequestTest, FileTestFullSpecifiedRange) { @@ -755,7 +761,7 @@ TEST_F(URLRequestTest, FileTestFullSpecifiedRange) { r.Start(); EXPECT_TRUE(r.is_pending()); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_TRUE(!r.is_pending()); EXPECT_EQ(1, d.response_started_count()); EXPECT_FALSE(d.received_data_before_response()); @@ -798,7 +804,7 @@ TEST_F(URLRequestTest, FileTestHalfSpecifiedRange) { r.Start(); EXPECT_TRUE(r.is_pending()); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_TRUE(!r.is_pending()); EXPECT_EQ(1, d.response_started_count()); EXPECT_FALSE(d.received_data_before_response()); @@ -834,7 +840,7 @@ TEST_F(URLRequestTest, FileTestMultipleRanges) { r.Start(); EXPECT_TRUE(r.is_pending()); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_TRUE(d.request_failed()); } @@ -849,7 +855,7 @@ TEST_F(URLRequestTest, InvalidUrlTest) { r.Start(); EXPECT_TRUE(r.is_pending()); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_TRUE(d.request_failed()); } } @@ -887,7 +893,7 @@ TEST_F(URLRequestTest, ResolveShortcutTest) { r.Start(); EXPECT_TRUE(r.is_pending()); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); WIN32_FILE_ATTRIBUTE_DATA data; GetFileAttributesEx(app_path.value().c_str(), @@ -931,7 +937,7 @@ TEST_F(URLRequestTest, FileDirCancelTest) { d.set_cancel_in_received_data_pending(true); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); } // Take out mock resource provider. @@ -952,7 +958,7 @@ TEST_F(URLRequestTest, FileDirRedirectNoCrash) { TestDelegate d; URLRequest req(FilePathToFileURL(path), &d, &default_context_); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); ASSERT_EQ(1, d.received_redirect_count()); ASSERT_LT(0, d.bytes_received()); @@ -966,7 +972,7 @@ TEST_F(URLRequestTest, FileDirRedirectSingleSlash) { TestDelegate d; URLRequest req(GURL("file:///"), &d, &default_context_); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); ASSERT_EQ(1, d.received_redirect_count()); ASSERT_FALSE(req.status().is_success()); @@ -1194,7 +1200,7 @@ TEST_F(URLRequestTest, Intercept) { req.SetUserData(&user_data2, user_data2); req.set_method("GET"); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); // Make sure we can retrieve our specific user data EXPECT_EQ(user_data0, req.GetUserData(NULL)); @@ -1229,7 +1235,7 @@ TEST_F(URLRequestTest, InterceptRedirect) { URLRequest req(GURL("http://test_intercept/foo"), &d, &default_context_); req.set_method("GET"); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); // Check the interceptor got called as expected EXPECT_TRUE(interceptor.did_intercept_main_); @@ -1262,7 +1268,7 @@ TEST_F(URLRequestTest, InterceptServerError) { URLRequest req(GURL("http://test_intercept/foo"), &d, &default_context_); req.set_method("GET"); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); // Check the interceptor got called as expected EXPECT_TRUE(interceptor.did_intercept_main_); @@ -1291,7 +1297,7 @@ TEST_F(URLRequestTest, InterceptNetworkError) { URLRequest req(GURL("http://test_intercept/foo"), &d, &default_context_); req.set_method("GET"); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); // Check the interceptor got called as expected EXPECT_TRUE(interceptor.did_simulate_error_main_); @@ -1320,7 +1326,7 @@ TEST_F(URLRequestTest, InterceptRestartRequired) { URLRequest req(GURL("http://test_intercept/foo"), &d, &default_context_); req.set_method("GET"); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); // Check the interceptor got called as expected EXPECT_TRUE(interceptor.did_restart_main_); @@ -1351,7 +1357,7 @@ TEST_F(URLRequestTest, InterceptRespectsCancelMain) { URLRequest req(GURL("http://test_intercept/foo"), &d, &default_context_); req.set_method("GET"); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); // Check the interceptor got called as expected EXPECT_TRUE(interceptor.did_cancel_main_); @@ -1382,7 +1388,7 @@ TEST_F(URLRequestTest, InterceptRespectsCancelRedirect) { URLRequest req(GURL("http://test_intercept/foo"), &d, &default_context_); req.set_method("GET"); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); // Check the interceptor got called as expected EXPECT_TRUE(interceptor.did_intercept_main_); @@ -1407,7 +1413,7 @@ TEST_F(URLRequestTest, InterceptRespectsCancelFinal) { URLRequest req(GURL("http://test_intercept/foo"), &d, &default_context_); req.set_method("GET"); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); // Check the interceptor got called as expected EXPECT_TRUE(interceptor.did_simulate_error_main_); @@ -1433,7 +1439,7 @@ TEST_F(URLRequestTest, InterceptRespectsCancelInRestart) { URLRequest req(GURL("http://test_intercept/foo"), &d, &default_context_); req.set_method("GET"); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); // Check the interceptor got called as expected EXPECT_TRUE(interceptor.did_cancel_then_restart_main_); @@ -1452,7 +1458,7 @@ LoadTimingInfo RunLoadTimingTest(const LoadTimingInfo& job_load_timing, TestDelegate d; URLRequest req(GURL("http://test_intercept/foo"), &d, context); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); LoadTimingInfo resulting_load_timing; req.GetLoadTimingInfo(&resulting_load_timing); @@ -1729,7 +1735,7 @@ TEST_F(URLRequestTest, NetworkDelegateProxyError) { req.set_method("GET"); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); // Check we see a failed request. EXPECT_FALSE(req.status().is_success()); @@ -1747,7 +1753,7 @@ TEST_F(URLRequestTest, RequestCompletionForEmptyResponse) { TestDelegate d; URLRequest req(GURL("data:,"), &d, &default_context_); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ("", d.data_received()); EXPECT_EQ(1, default_network_delegate_.completed_requests()); } @@ -1841,7 +1847,7 @@ TEST_F(URLRequestTest, DelayedCookieCallback) { URLRequest req( test_server.GetURL("set-cookie?CookieToNotSend=1"), &d, &context); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(0, network_delegate.blocked_get_cookies_count()); EXPECT_EQ(0, network_delegate.blocked_set_cookie_count()); EXPECT_EQ(1, network_delegate.set_cookie_count()); @@ -1854,7 +1860,7 @@ TEST_F(URLRequestTest, DelayedCookieCallback) { TestDelegate d; URLRequest req(test_server.GetURL("echoheader?Cookie"), &d, &context); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_TRUE(d.data_received().find("CookieToNotSend=1") != std::string::npos); @@ -1876,7 +1882,7 @@ TEST_F(URLRequestTest, DoNotSendCookies) { &d, &default_context_); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(0, network_delegate.blocked_get_cookies_count()); EXPECT_EQ(0, network_delegate.blocked_set_cookie_count()); } @@ -1889,7 +1895,7 @@ TEST_F(URLRequestTest, DoNotSendCookies) { URLRequest req( test_server.GetURL("echoheader?Cookie"), &d, &default_context_); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_TRUE(d.data_received().find("CookieToNotSend=1") != std::string::npos); @@ -1906,7 +1912,7 @@ TEST_F(URLRequestTest, DoNotSendCookies) { test_server.GetURL("echoheader?Cookie"), &d, &default_context_); req.set_load_flags(LOAD_DO_NOT_SEND_COOKIES); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_TRUE(d.data_received().find("Cookie: CookieToNotSend=1") == std::string::npos); @@ -1930,7 +1936,7 @@ TEST_F(URLRequestTest, DoNotSaveCookies) { &d, &default_context_); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(0, network_delegate.blocked_get_cookies_count()); EXPECT_EQ(0, network_delegate.blocked_set_cookie_count()); @@ -1949,7 +1955,7 @@ TEST_F(URLRequestTest, DoNotSaveCookies) { req.set_load_flags(LOAD_DO_NOT_SAVE_COOKIES); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); // LOAD_DO_NOT_SAVE_COOKIES does not trigger OnSetCookie. EXPECT_EQ(0, network_delegate.blocked_get_cookies_count()); @@ -1965,7 +1971,7 @@ TEST_F(URLRequestTest, DoNotSaveCookies) { URLRequest req( test_server.GetURL("echoheader?Cookie"), &d, &default_context_); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_TRUE(d.data_received().find("CookieToNotSave=1") == std::string::npos); @@ -1991,7 +1997,7 @@ TEST_F(URLRequestTest, DoNotSendCookies_ViaPolicy) { &d, &default_context_); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(0, network_delegate.blocked_get_cookies_count()); EXPECT_EQ(0, network_delegate.blocked_set_cookie_count()); @@ -2005,7 +2011,7 @@ TEST_F(URLRequestTest, DoNotSendCookies_ViaPolicy) { URLRequest req( test_server.GetURL("echoheader?Cookie"), &d, &default_context_); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_TRUE(d.data_received().find("CookieToNotSend=1") != std::string::npos); @@ -2023,7 +2029,7 @@ TEST_F(URLRequestTest, DoNotSendCookies_ViaPolicy) { URLRequest req( test_server.GetURL("echoheader?Cookie"), &d, &default_context_); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_TRUE(d.data_received().find("Cookie: CookieToNotSend=1") == std::string::npos); @@ -2046,7 +2052,7 @@ TEST_F(URLRequestTest, DoNotSaveCookies_ViaPolicy) { &d, &default_context_); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(0, network_delegate.blocked_get_cookies_count()); EXPECT_EQ(0, network_delegate.blocked_set_cookie_count()); @@ -2064,7 +2070,7 @@ TEST_F(URLRequestTest, DoNotSaveCookies_ViaPolicy) { &default_context_); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(0, network_delegate.blocked_get_cookies_count()); EXPECT_EQ(2, network_delegate.blocked_set_cookie_count()); @@ -2078,7 +2084,7 @@ TEST_F(URLRequestTest, DoNotSaveCookies_ViaPolicy) { URLRequest req( test_server.GetURL("echoheader?Cookie"), &d, &default_context_); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_TRUE(d.data_received().find("CookieToNotSave=1") == std::string::npos); @@ -2101,7 +2107,7 @@ TEST_F(URLRequestTest, DoNotSaveEmptyCookies) { TestDelegate d; URLRequest req(test_server.GetURL("set-cookie"), &d, &default_context_); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(0, network_delegate.blocked_get_cookies_count()); EXPECT_EQ(0, network_delegate.blocked_set_cookie_count()); @@ -2122,7 +2128,7 @@ TEST_F(URLRequestTest, DoNotSendCookies_ViaPolicy_Async) { &d, &default_context_); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(0, network_delegate.blocked_get_cookies_count()); EXPECT_EQ(0, network_delegate.blocked_set_cookie_count()); @@ -2136,7 +2142,7 @@ TEST_F(URLRequestTest, DoNotSendCookies_ViaPolicy_Async) { URLRequest req( test_server.GetURL("echoheader?Cookie"), &d, &default_context_); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_TRUE(d.data_received().find("CookieToNotSend=1") != std::string::npos); @@ -2154,7 +2160,7 @@ TEST_F(URLRequestTest, DoNotSendCookies_ViaPolicy_Async) { URLRequest req( test_server.GetURL("echoheader?Cookie"), &d, &default_context_); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_TRUE(d.data_received().find("Cookie: CookieToNotSend=1") == std::string::npos); @@ -2177,7 +2183,7 @@ TEST_F(URLRequestTest, DoNotSaveCookies_ViaPolicy_Async) { &d, &default_context_); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(0, network_delegate.blocked_get_cookies_count()); EXPECT_EQ(0, network_delegate.blocked_set_cookie_count()); @@ -2195,7 +2201,7 @@ TEST_F(URLRequestTest, DoNotSaveCookies_ViaPolicy_Async) { &default_context_); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(0, network_delegate.blocked_get_cookies_count()); EXPECT_EQ(2, network_delegate.blocked_set_cookie_count()); @@ -2209,7 +2215,7 @@ TEST_F(URLRequestTest, DoNotSaveCookies_ViaPolicy_Async) { URLRequest req( test_server.GetURL("echoheader?Cookie"), &d, &default_context_); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_TRUE(d.data_received().find("CookieToNotSave=1") == std::string::npos); @@ -2278,7 +2284,7 @@ TEST_F(URLRequestTest, AcceptClockSkewCookieWithWrongDateTimezone) { &d, &default_context_); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); } // Verify that the cookie is not set. { @@ -2288,7 +2294,7 @@ TEST_F(URLRequestTest, AcceptClockSkewCookieWithWrongDateTimezone) { URLRequest req( test_server.GetURL("echoheader?Cookie"), &d, &default_context_); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_TRUE(d.data_received().find("StillGood=1") == std::string::npos); } @@ -2302,7 +2308,7 @@ TEST_F(URLRequestTest, AcceptClockSkewCookieWithWrongDateTimezone) { &d, &default_context_); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); } // Verify that the cookie is set. { @@ -2312,7 +2318,7 @@ TEST_F(URLRequestTest, AcceptClockSkewCookieWithWrongDateTimezone) { URLRequest req( test_server.GetURL("echoheader?Cookie"), &d, &default_context_); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_TRUE(d.data_received().find("StillGood=1") != std::string::npos); } @@ -2338,7 +2344,7 @@ TEST_F(URLRequestTest, DoNotOverrideReferrer) { req.SetExtraRequestHeaders(headers); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ("http://foo.com/", d.data_received()); } @@ -2356,7 +2362,7 @@ TEST_F(URLRequestTest, DoNotOverrideReferrer) { req.set_load_flags(LOAD_VALIDATE_CACHE); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ("None", d.data_received()); } @@ -2393,7 +2399,7 @@ class URLRequestTestHTTP : public URLRequestTest { req.SetExtraRequestHeaders(headers); } req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(redirect_method, req.method()); EXPECT_EQ(URLRequestStatus::SUCCESS, req.status().status()); EXPECT_EQ(OK, req.status().error()); @@ -2436,7 +2442,7 @@ class URLRequestTestHTTP : public URLRequestTest { r.Start(); EXPECT_TRUE(r.is_pending()); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); ASSERT_EQ(1, d.response_started_count()) << "request failed: " << r.status().status() @@ -2482,7 +2488,7 @@ class URLRequestTestHTTP : public URLRequestTest { r.Start(); EXPECT_TRUE(r.is_pending()); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); bool is_success = r.status().is_success(); @@ -2521,7 +2527,7 @@ TEST_F(URLRequestTestHTTP, ProxyTunnelRedirectTest) { r.Start(); EXPECT_TRUE(r.is_pending()); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(URLRequestStatus::FAILED, r.status().status()); EXPECT_EQ(ERR_TUNNEL_CONNECTION_FAILED, r.status().error()); @@ -2547,7 +2553,7 @@ TEST_F(URLRequestTestHTTP, NetworkDelegateTunnelConnectionFailed) { r.Start(); EXPECT_TRUE(r.is_pending()); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(URLRequestStatus::FAILED, r.status().status()); EXPECT_EQ(ERR_TUNNEL_CONNECTION_FAILED, r.status().error()); @@ -2588,12 +2594,12 @@ TEST_F(URLRequestTestHTTP, NetworkDelegateBlockAsynchronously) { r.Start(); for (size_t i = 0; i < blocking_stages_length; ++i) { - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(blocking_stages[i], network_delegate.stage_blocked_for_callback()); network_delegate.DoCallback(OK); } - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(200, r.GetResponseCode()); EXPECT_EQ(URLRequestStatus::SUCCESS, r.status().status()); EXPECT_EQ(1, network_delegate.created_requests()); @@ -2620,7 +2626,7 @@ TEST_F(URLRequestTestHTTP, NetworkDelegateCancelRequest) { URLRequest r(test_server_.GetURL(std::string()), &d, &context); r.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(URLRequestStatus::FAILED, r.status().status()); EXPECT_EQ(ERR_EMPTY_RESPONSE, r.status().error()); @@ -2650,7 +2656,7 @@ void NetworkDelegateCancelRequest(BlockingNetworkDelegate::BlockMode block_mode, URLRequest r(url, &d, &context); r.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(URLRequestStatus::FAILED, r.status().status()); EXPECT_EQ(ERR_BLOCKED_BY_CLIENT, r.status().error()); @@ -2727,7 +2733,7 @@ TEST_F(URLRequestTestHTTP, NetworkDelegateRedirectRequest) { URLRequest r(original_url, &d, &context); r.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(URLRequestStatus::SUCCESS, r.status().status()); EXPECT_EQ(0, r.status().error()); @@ -2760,7 +2766,7 @@ TEST_F(URLRequestTestHTTP, NetworkDelegateRedirectRequestSynchronously) { URLRequest r(original_url, &d, &context); r.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(URLRequestStatus::SUCCESS, r.status().status()); EXPECT_EQ(0, r.status().error()); @@ -2800,7 +2806,7 @@ TEST_F(URLRequestTestHTTP, NetworkDelegateRedirectRequestPost) { base::UintToString(arraysize(kData) - 1)); r.SetExtraRequestHeaders(headers); r.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(URLRequestStatus::SUCCESS, r.status().status()); EXPECT_EQ(0, r.status().error()); @@ -2837,7 +2843,7 @@ TEST_F(URLRequestTestHTTP, NetworkDelegateOnAuthRequiredSyncNoAction) { URLRequest r(url, &d, &context); r.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(URLRequestStatus::SUCCESS, r.status().status()); EXPECT_EQ(0, r.status().error()); @@ -2874,7 +2880,7 @@ TEST_F(URLRequestTestHTTP, EXPECT_FALSE(headers.HasHeader("Authorization")); } - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(URLRequestStatus::SUCCESS, r.status().status()); EXPECT_EQ(0, r.status().error()); @@ -2908,7 +2914,7 @@ TEST_F(URLRequestTestHTTP, NetworkDelegateOnAuthRequiredSyncSetAuth) { GURL url(test_server_.GetURL("auth-basic")); URLRequest r(url, &d, &context); r.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(URLRequestStatus::SUCCESS, r.status().status()); EXPECT_EQ(0, r.status().error()); @@ -2943,7 +2949,7 @@ TEST_F(URLRequestTestHTTP, GURL url(test_server_.GetURL("auth-basic")); URLRequest r(url, &d, &context); r.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(URLRequestStatus::SUCCESS, r.status().status()); EXPECT_EQ(0, r.status().error()); @@ -2981,7 +2987,7 @@ TEST_F(URLRequestTestHTTP, NetworkDelegateOnAuthRequiredSyncCancel) { GURL url(test_server_.GetURL("auth-basic")); URLRequest r(url, &d, &context); r.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(URLRequestStatus::SUCCESS, r.status().status()); EXPECT_EQ(OK, r.status().error()); @@ -3015,7 +3021,7 @@ TEST_F(URLRequestTestHTTP, NetworkDelegateOnAuthRequiredAsyncNoAction) { GURL url(test_server_.GetURL("auth-basic")); URLRequest r(url, &d, &context); r.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(URLRequestStatus::SUCCESS, r.status().status()); EXPECT_EQ(0, r.status().error()); @@ -3050,7 +3056,7 @@ TEST_F(URLRequestTestHTTP, NetworkDelegateOnAuthRequiredAsyncSetAuth) { GURL url(test_server_.GetURL("auth-basic")); URLRequest r(url, &d, &context); r.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(URLRequestStatus::SUCCESS, r.status().status()); EXPECT_EQ(0, r.status().error()); @@ -3083,7 +3089,7 @@ TEST_F(URLRequestTestHTTP, NetworkDelegateOnAuthRequiredAsyncCancel) { GURL url(test_server_.GetURL("auth-basic")); URLRequest r(url, &d, &context); r.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(URLRequestStatus::SUCCESS, r.status().status()); EXPECT_EQ(OK, r.status().error()); @@ -3114,7 +3120,7 @@ TEST_F(URLRequestTestHTTP, NetworkDelegateCancelWhileWaiting1) { URLRequest r(test_server_.GetURL(std::string()), &d, &context); r.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(BlockingNetworkDelegate::ON_BEFORE_URL_REQUEST, network_delegate.stage_blocked_for_callback()); EXPECT_EQ(0, network_delegate.completed_requests()); @@ -3150,7 +3156,7 @@ TEST_F(URLRequestTestHTTP, NetworkDelegateCancelWhileWaiting2) { URLRequest r(test_server_.GetURL(std::string()), &d, &context); r.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(BlockingNetworkDelegate::ON_BEFORE_SEND_HEADERS, network_delegate.stage_blocked_for_callback()); EXPECT_EQ(0, network_delegate.completed_requests()); @@ -3185,7 +3191,7 @@ TEST_F(URLRequestTestHTTP, NetworkDelegateCancelWhileWaiting3) { URLRequest r(test_server_.GetURL(std::string()), &d, &context); r.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(BlockingNetworkDelegate::ON_HEADERS_RECEIVED, network_delegate.stage_blocked_for_callback()); EXPECT_EQ(0, network_delegate.completed_requests()); @@ -3220,7 +3226,7 @@ TEST_F(URLRequestTestHTTP, NetworkDelegateCancelWhileWaiting4) { URLRequest r(test_server_.GetURL("auth-basic"), &d, &context); r.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(BlockingNetworkDelegate::ON_AUTH_REQUIRED, network_delegate.stage_blocked_for_callback()); EXPECT_EQ(0, network_delegate.completed_requests()); @@ -3254,7 +3260,7 @@ TEST_F(URLRequestTestHTTP, UnexpectedServerAuthTest) { r.Start(); EXPECT_TRUE(r.is_pending()); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(URLRequestStatus::FAILED, r.status().status()); EXPECT_EQ(ERR_TUNNEL_CONNECTION_FAILED, r.status().error()); @@ -3271,7 +3277,7 @@ TEST_F(URLRequestTestHTTP, GetTest_NoCache) { r.Start(); EXPECT_TRUE(r.is_pending()); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(1, d.response_started_count()); EXPECT_FALSE(d.received_data_before_response()); @@ -3337,7 +3343,7 @@ TEST_F(URLRequestTestHTTP, GetTest) { r.Start(); EXPECT_TRUE(r.is_pending()); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(1, d.response_started_count()); EXPECT_FALSE(d.received_data_before_response()); @@ -3363,7 +3369,7 @@ TEST_F(URLRequestTestHTTP, GetTest_GetFullRequestHeaders) { r.Start(); EXPECT_TRUE(r.is_pending()); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(1, d.response_started_count()); EXPECT_FALSE(d.received_data_before_response()); @@ -3388,7 +3394,7 @@ TEST_F(URLRequestTestHTTP, GetTestLoadTiming) { r.Start(); EXPECT_TRUE(r.is_pending()); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); LoadTimingInfo load_timing_info; r.GetLoadTimingInfo(&load_timing_info); @@ -3437,7 +3443,7 @@ TEST_F(URLRequestTestHTTP, GetZippedTest) { r.Start(); EXPECT_TRUE(r.is_pending()); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(1, d.response_started_count()); EXPECT_FALSE(d.received_data_before_response()); @@ -3473,7 +3479,7 @@ TEST_F(URLRequestTestHTTP, HTTPSToHTTPRedirectNoRefererTest) { "server-redirect?" + http_destination.spec()), &d, &default_context_); req.SetReferrer("https://www.referrer.com/"); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(1, d.response_started_count()); EXPECT_EQ(1, d.received_redirect_count()); @@ -3490,7 +3496,7 @@ TEST_F(URLRequestTestHTTP, RedirectLoadTiming) { TestDelegate d; URLRequest req(original_url, &d, &default_context_); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(1, d.response_started_count()); EXPECT_EQ(1, d.received_redirect_count()); @@ -3530,7 +3536,7 @@ TEST_F(URLRequestTestHTTP, MultipleRedirectTest) { TestDelegate d; URLRequest req(original_url, &d, &default_context_); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(1, d.response_started_count()); EXPECT_EQ(2, d.received_redirect_count()); @@ -3568,7 +3574,7 @@ TEST_F(URLRequestTestHTTP, RedirectWithAdditionalHeadersTest) { RedirectWithAdditionalHeadersDelegate d; URLRequest req(original_url, &d, &default_context_); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); std::string value; const HttpRequestHeaders& headers = req.extra_request_headers(); @@ -3605,7 +3611,7 @@ TEST_F(URLRequestTestHTTP, RedirectWithHeaderRemovalTest) { URLRequest req(original_url, &d, &default_context_); req.SetExtraRequestHeaderByName(kExtraHeaderToRemove, "dummy", false); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); std::string value; const HttpRequestHeaders& headers = req.extra_request_headers(); @@ -3625,7 +3631,7 @@ TEST_F(URLRequestTestHTTP, CancelTest) { r.Cancel(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); // We expect to receive OnResponseStarted even though the request has been // cancelled. @@ -3647,7 +3653,7 @@ TEST_F(URLRequestTestHTTP, CancelTest2) { r.Start(); EXPECT_TRUE(r.is_pending()); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(1, d.response_started_count()); EXPECT_EQ(0, d.bytes_received()); @@ -3668,7 +3674,7 @@ TEST_F(URLRequestTestHTTP, CancelTest3) { r.Start(); EXPECT_TRUE(r.is_pending()); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(1, d.response_started_count()); // There is no guarantee about how much data was received @@ -3713,7 +3719,7 @@ TEST_F(URLRequestTestHTTP, CancelTest5) { TestDelegate d; URLRequest r(test_server_.GetURL("cachetime"), &d, &default_context_); r.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(URLRequestStatus::SUCCESS, r.status().status()); } @@ -3723,7 +3729,7 @@ TEST_F(URLRequestTestHTTP, CancelTest5) { URLRequest r(test_server_.GetURL("cachetime"), &d, &default_context_); r.Start(); r.Cancel(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(URLRequestStatus::CANCELED, r.status().status()); EXPECT_EQ(1, d.response_started_count()); @@ -3753,7 +3759,7 @@ TEST_F(URLRequestTestHTTP, PostEmptyTest) { r.Start(); EXPECT_TRUE(r.is_pending()); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); ASSERT_EQ(1, d.response_started_count()) << "request failed: " << r.status().status() @@ -3804,7 +3810,7 @@ TEST_F(URLRequestTestHTTP, PostFileTest) { r.Start(); EXPECT_TRUE(r.is_pending()); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); int64 size = 0; ASSERT_EQ(true, file_util::GetFileSize(path, &size)); @@ -3835,7 +3841,7 @@ TEST_F(URLRequestTestHTTP, TestPostChunkedDataBeforeStart) { r.Start(); EXPECT_TRUE(r.is_pending()); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); VerifyReceivedDataMatchesChunks(&r, &d); } @@ -3852,7 +3858,7 @@ TEST_F(URLRequestTestHTTP, TestPostChunkedDataJustAfterStart) { r.Start(); EXPECT_TRUE(r.is_pending()); AddChunksToUpload(&r); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); VerifyReceivedDataMatchesChunks(&r, &d); } @@ -3869,9 +3875,9 @@ TEST_F(URLRequestTestHTTP, TestPostChunkedDataAfterStart) { r.Start(); EXPECT_TRUE(r.is_pending()); - base::MessageLoop::current()->RunUntilIdle(); + base::RunLoop().RunUntilIdle(); AddChunksToUpload(&r); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); VerifyReceivedDataMatchesChunks(&r, &d); } @@ -3884,7 +3890,7 @@ TEST_F(URLRequestTestHTTP, ResponseHeadersTest) { URLRequest req( test_server_.GetURL("files/with-headers.html"), &d, &default_context_); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); const HttpResponseHeaders* headers = req.response_headers(); @@ -3920,7 +3926,7 @@ TEST_F(URLRequestTestHTTP, ProcessSTS) { &d, &default_context_); request.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); TransportSecurityState* security_state = default_context_.transport_security_state(); @@ -3964,7 +3970,7 @@ TEST_F(URLRequestTestHTTP, MAYBE_ProcessPKP) { &d, &default_context_); request.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); TransportSecurityState* security_state = default_context_.transport_security_state(); @@ -3995,7 +4001,7 @@ TEST_F(URLRequestTestHTTP, ProcessSTSOnce) { &d, &default_context_); request.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); // We should have set parameters from the first header, not the second. TransportSecurityState* security_state = @@ -4024,7 +4030,7 @@ TEST_F(URLRequestTestHTTP, ProcessSTSAndPKP) { &d, &default_context_); request.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); // We should have set parameters from the first header, not the second. TransportSecurityState* security_state = @@ -4066,7 +4072,7 @@ TEST_F(URLRequestTestHTTP, ProcessSTSAndPKP2) { &d, &default_context_); request.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); TransportSecurityState* security_state = default_context_.transport_security_state(); @@ -4095,7 +4101,7 @@ TEST_F(URLRequestTestHTTP, ContentTypeNormalizationTest) { URLRequest req(test_server_.GetURL( "files/content-type-normalization.html"), &d, &default_context_); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); std::string mime_type; req.GetMimeType(&mime_type); @@ -4111,23 +4117,36 @@ TEST_F(URLRequestTestHTTP, ProtocolHandlerAndFactoryRestrictRedirects) { // Test URLRequestJobFactory::ProtocolHandler::IsSafeRedirectTarget(). GURL file_url("file:///foo.txt"); GURL data_url("data:,foo"); - FileProtocolHandler file_protocol_handler; + FileProtocolHandler file_protocol_handler(base::MessageLoopProxy::current()); EXPECT_FALSE(file_protocol_handler.IsSafeRedirectTarget(file_url)); DataProtocolHandler data_protocol_handler; - EXPECT_TRUE(data_protocol_handler.IsSafeRedirectTarget(data_url)); + EXPECT_FALSE(data_protocol_handler.IsSafeRedirectTarget(data_url)); // Test URLRequestJobFactoryImpl::IsSafeRedirectTarget(). EXPECT_FALSE(job_factory_.IsSafeRedirectTarget(file_url)); - EXPECT_TRUE(job_factory_.IsSafeRedirectTarget(data_url)); + EXPECT_FALSE(job_factory_.IsSafeRedirectTarget(data_url)); } -TEST_F(URLRequestTestHTTP, RestrictRedirects) { +TEST_F(URLRequestTestHTTP, RestrictFileRedirects) { ASSERT_TRUE(test_server_.Start()); TestDelegate d; URLRequest req(test_server_.GetURL( "files/redirect-to-file.html"), &d, &default_context_); req.Start(); + base::RunLoop().Run(); + + EXPECT_EQ(URLRequestStatus::FAILED, req.status().status()); + EXPECT_EQ(ERR_UNSAFE_REDIRECT, req.status().error()); +} + +TEST_F(URLRequestTestHTTP, RestrictDataRedirects) { + ASSERT_TRUE(test_server_.Start()); + + TestDelegate d; + URLRequest req(test_server_.GetURL( + "files/redirect-to-data.html"), &d, &default_context_); + req.Start(); base::MessageLoop::current()->Run(); EXPECT_EQ(URLRequestStatus::FAILED, req.status().status()); @@ -4141,7 +4160,7 @@ TEST_F(URLRequestTestHTTP, RedirectToInvalidURL) { URLRequest req(test_server_.GetURL( "files/redirect-to-invalid-url.html"), &d, &default_context_); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(URLRequestStatus::FAILED, req.status().status()); EXPECT_EQ(ERR_INVALID_URL, req.status().error()); @@ -4155,7 +4174,7 @@ TEST_F(URLRequestTestHTTP, NoUserPassInReferrer) { test_server_.GetURL("echoheader?Referer"), &d, &default_context_); req.SetReferrer("http://user:pass@foo.com/"); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(std::string("http://foo.com/"), d.data_received()); } @@ -4168,7 +4187,7 @@ TEST_F(URLRequestTestHTTP, NoFragmentInReferrer) { test_server_.GetURL("echoheader?Referer"), &d, &default_context_); req.SetReferrer("http://foo.com/test#fragment"); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(std::string("http://foo.com/test"), d.data_received()); } @@ -4182,7 +4201,7 @@ TEST_F(URLRequestTestHTTP, EmptyReferrerAfterValidReferrer) { req.SetReferrer("http://foo.com/test#fragment"); req.SetReferrer(""); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(std::string("None"), d.data_received()); } @@ -4196,7 +4215,7 @@ TEST_F(URLRequestTestHTTP, CancelRedirect) { URLRequest req( test_server_.GetURL("files/redirect-test.html"), &d, &default_context_); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(1, d.response_started_count()); EXPECT_EQ(0, d.bytes_received()); @@ -4215,12 +4234,12 @@ TEST_F(URLRequestTestHTTP, DeferredRedirect) { URLRequest req(test_url, &d, &default_context_); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(1, d.received_redirect_count()); req.FollowDeferredRedirect(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(1, d.response_started_count()); EXPECT_FALSE(d.received_data_before_response()); @@ -4234,7 +4253,7 @@ TEST_F(URLRequestTestHTTP, DeferredRedirect) { path = path.Append(FILE_PATH_LITERAL("with-headers.html")); std::string contents; - EXPECT_TRUE(file_util::ReadFileToString(path, &contents)); + EXPECT_TRUE(base::ReadFileToString(path, &contents)); EXPECT_EQ(contents, d.data_received()); } } @@ -4251,7 +4270,7 @@ TEST_F(URLRequestTestHTTP, DeferredRedirect_GetFullRequestHeaders) { EXPECT_FALSE(d.have_full_request_headers()); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(1, d.received_redirect_count()); EXPECT_TRUE(d.have_full_request_headers()); @@ -4259,7 +4278,7 @@ TEST_F(URLRequestTestHTTP, DeferredRedirect_GetFullRequestHeaders) { d.ClearFullRequestHeaders(); req.FollowDeferredRedirect(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); GURL target_url(test_server_.GetURL("files/with-headers.html")); EXPECT_EQ(1, d.response_started_count()); @@ -4276,7 +4295,7 @@ TEST_F(URLRequestTestHTTP, DeferredRedirect_GetFullRequestHeaders) { path = path.Append(FILE_PATH_LITERAL("with-headers.html")); std::string contents; - EXPECT_TRUE(file_util::ReadFileToString(path, &contents)); + EXPECT_TRUE(base::ReadFileToString(path, &contents)); EXPECT_EQ(contents, d.data_received()); } } @@ -4290,12 +4309,12 @@ TEST_F(URLRequestTestHTTP, CancelDeferredRedirect) { URLRequest req( test_server_.GetURL("files/redirect-test.html"), &d, &default_context_); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(1, d.received_redirect_count()); req.Cancel(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(1, d.response_started_count()); EXPECT_EQ(0, d.bytes_received()); @@ -4316,7 +4335,7 @@ TEST_F(URLRequestTestHTTP, VaryHeader) { headers.SetHeader("foo", "1"); req.SetExtraRequestHeaders(headers); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); LoadTimingInfo load_timing_info; req.GetLoadTimingInfo(&load_timing_info); @@ -4332,7 +4351,7 @@ TEST_F(URLRequestTestHTTP, VaryHeader) { headers.SetHeader("foo", "1"); req.SetExtraRequestHeaders(headers); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_TRUE(req.was_cached()); @@ -4350,7 +4369,7 @@ TEST_F(URLRequestTestHTTP, VaryHeader) { headers.SetHeader("foo", "2"); req.SetExtraRequestHeaders(headers); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_FALSE(req.was_cached()); @@ -4371,7 +4390,7 @@ TEST_F(URLRequestTestHTTP, BasicAuth) { URLRequest r(test_server_.GetURL("auth-basic"), &d, &default_context_); r.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_TRUE(d.data_received().find("user/secret") != std::string::npos); } @@ -4387,7 +4406,7 @@ TEST_F(URLRequestTestHTTP, BasicAuth) { r.set_load_flags(LOAD_VALIDATE_CACHE); r.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_TRUE(d.data_received().find("user/secret") != std::string::npos); @@ -4418,7 +4437,7 @@ TEST_F(URLRequestTestHTTP, BasicAuthWithCookies) { URLRequest r(url_requiring_auth, &d, &context); r.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_TRUE(d.data_received().find("user/secret") != std::string::npos); @@ -4447,7 +4466,7 @@ TEST_F(URLRequestTestHTTP, BasicAuthWithCookies) { URLRequest r(url_with_identity, &d, &context); r.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_TRUE(d.data_received().find("user2/secret") != std::string::npos); @@ -4469,7 +4488,7 @@ TEST_F(URLRequestTestHTTP, BasicAuthLoadTiming) { URLRequest r(test_server_.GetURL("auth-basic"), &d, &default_context_); r.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_TRUE(d.data_received().find("user/secret") != std::string::npos); @@ -4501,7 +4520,7 @@ TEST_F(URLRequestTestHTTP, BasicAuthLoadTiming) { r.set_load_flags(LOAD_VALIDATE_CACHE); r.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_TRUE(d.data_received().find("user/secret") != std::string::npos); @@ -4544,7 +4563,7 @@ TEST_F(URLRequestTestHTTP, Post302RedirectGet) { "Origin: http://localhost:1337/"); req.SetExtraRequestHeaders(headers); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); std::string mime_type; req.GetMimeType(&mime_type); @@ -4627,7 +4646,7 @@ TEST_F(URLRequestTestHTTP, InterceptPost302RedirectGet) { AddTestInterceptor()->set_main_intercept_job(job); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ("GET", req.method()); } @@ -4651,7 +4670,7 @@ TEST_F(URLRequestTestHTTP, InterceptPost307RedirectPost) { AddTestInterceptor()->set_main_intercept_job(job); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ("POST", req.method()); EXPECT_EQ(kData, d.data_received()); } @@ -4671,7 +4690,7 @@ TEST_F(URLRequestTestHTTP, DefaultAcceptLanguage) { URLRequest req( test_server_.GetURL("echoheader?Accept-Language"), &d, &context); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ("en", d.data_received()); } @@ -4692,7 +4711,7 @@ TEST_F(URLRequestTestHTTP, EmptyAcceptLanguage) { URLRequest req( test_server_.GetURL("echoheader?Accept-Language"), &d, &context); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ("None", d.data_received()); } @@ -4709,7 +4728,7 @@ TEST_F(URLRequestTestHTTP, OverrideAcceptLanguage) { headers.SetHeader(HttpRequestHeaders::kAcceptLanguage, "ru"); req.SetExtraRequestHeaders(headers); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(std::string("ru"), d.data_received()); } @@ -4724,7 +4743,7 @@ TEST_F(URLRequestTestHTTP, DefaultAcceptEncoding) { HttpRequestHeaders headers; req.SetExtraRequestHeaders(headers); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_TRUE(ContainsString(d.data_received(), "gzip")); } @@ -4741,7 +4760,7 @@ TEST_F(URLRequestTestHTTP, OverrideAcceptEncoding) { headers.SetHeader(HttpRequestHeaders::kAcceptEncoding, "identity"); req.SetExtraRequestHeaders(headers); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_FALSE(ContainsString(d.data_received(), "gzip")); EXPECT_TRUE(ContainsString(d.data_received(), "identity")); } @@ -4758,7 +4777,7 @@ TEST_F(URLRequestTestHTTP, SetAcceptCharset) { headers.SetHeader(HttpRequestHeaders::kAcceptCharset, "koi-8r"); req.SetExtraRequestHeaders(headers); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(std::string("koi-8r"), d.data_received()); } @@ -4771,7 +4790,7 @@ TEST_F(URLRequestTestHTTP, DefaultUserAgent) { &d, &default_context_); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(req.context()->GetUserAgent(req.url()), d.data_received()); } @@ -4788,7 +4807,7 @@ TEST_F(URLRequestTestHTTP, OverrideUserAgent) { headers.SetHeader(HttpRequestHeaders::kUserAgent, "Lynx (textmode)"); req.SetExtraRequestHeaders(headers); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); // If the net tests are being run with ChromeFrame then we need to allow for // the 'chromeframe' suffix which is added to the user agent before the // closing parentheses. @@ -4820,7 +4839,7 @@ TEST_F(URLRequestTestHTTP, EmptyHttpUserAgentSettings) { TestDelegate d; URLRequest req(test_server_.GetURL(tests[i].request), &d, &context); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(tests[i].expected_response, d.data_received()) << " Request = \"" << tests[i].request << "\""; } @@ -4850,7 +4869,7 @@ TEST_F(URLRequestTestHTTP, SetSubsequentJobPriority) { AddTestInterceptor()->set_main_intercept_job(job.get()); // Should trigger |job| to be started. - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(LOW, job->priority()); } @@ -4880,7 +4899,7 @@ TEST_F(HTTPSRequestTest, HTTPSGetTest) { r.Start(); EXPECT_TRUE(r.is_pending()); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(1, d.response_started_count()); EXPECT_FALSE(d.received_data_before_response()); @@ -4912,7 +4931,7 @@ TEST_F(HTTPSRequestTest, HTTPSMismatchedTest) { r.Start(); EXPECT_TRUE(r.is_pending()); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(1, d.response_started_count()); EXPECT_FALSE(d.received_data_before_response()); @@ -4948,7 +4967,7 @@ TEST_F(HTTPSRequestTest, HTTPSExpiredTest) { r.Start(); EXPECT_TRUE(r.is_pending()); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(1, d.response_started_count()); EXPECT_FALSE(d.received_data_before_response()); @@ -4991,7 +5010,7 @@ TEST_F(HTTPSRequestTest, TLSv1Fallback) { URLRequest r(test_server.GetURL(std::string()), &d, &context); r.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(1, d.response_started_count()); EXPECT_NE(0, d.bytes_received()); @@ -5035,7 +5054,7 @@ TEST_F(HTTPSRequestTest, HTTPSPreloadedHSTSTest) { r.Start(); EXPECT_TRUE(r.is_pending()); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(1, d.response_started_count()); EXPECT_FALSE(d.received_data_before_response()); @@ -5083,7 +5102,7 @@ TEST_F(HTTPSRequestTest, HTTPSErrorsNoClobberTSSTest) { r.Start(); EXPECT_TRUE(r.is_pending()); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(1, d.response_started_count()); EXPECT_FALSE(d.received_data_before_response()); @@ -5154,7 +5173,7 @@ TEST_F(HTTPSRequestTest, HSTSPreservesPosts) { req.set_upload(make_scoped_ptr(CreateSimpleUploadData(kData))); req.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ("https", req.url().scheme()); EXPECT_EQ("POST", req.method()); @@ -5184,7 +5203,7 @@ TEST_F(HTTPSRequestTest, SSLv3Fallback) { URLRequest r(test_server.GetURL(std::string()), &d, &context); r.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(1, d.response_started_count()); EXPECT_NE(0, d.bytes_received()); @@ -5235,7 +5254,7 @@ TEST_F(HTTPSRequestTest, ClientAuthTest) { r.Start(); EXPECT_TRUE(r.is_pending()); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(1, d.on_certificate_requested_count()); EXPECT_FALSE(d.received_data_before_response()); @@ -5246,7 +5265,7 @@ TEST_F(HTTPSRequestTest, ClientAuthTest) { // all platforms so we can test sending a cert as well. r.ContinueWithCertificate(NULL); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(1, d.response_started_count()); EXPECT_FALSE(d.received_data_before_response()); @@ -5275,7 +5294,7 @@ TEST_F(HTTPSRequestTest, ResumeTest) { r.Start(); EXPECT_TRUE(r.is_pending()); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(1, d.response_started_count()); } @@ -5291,7 +5310,7 @@ TEST_F(HTTPSRequestTest, ResumeTest) { r.Start(); EXPECT_TRUE(r.is_pending()); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); // The response will look like; // insert abc @@ -5344,7 +5363,7 @@ TEST_F(HTTPSRequestTest, SSLSessionCacheShardTest) { r.Start(); EXPECT_TRUE(r.is_pending()); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(1, d.response_started_count()); } @@ -5376,7 +5395,7 @@ TEST_F(HTTPSRequestTest, SSLSessionCacheShardTest) { r.Start(); EXPECT_TRUE(r.is_pending()); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); // The response will look like; // insert abc @@ -5492,7 +5511,7 @@ class HTTPSOCSPTest : public HTTPSRequestTest { URLRequest r(test_server.GetURL(std::string()), &d, &context_); r.Start(); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_EQ(1, d.response_started_count()); *out_cert_status = r.ssl_info().cert_status; @@ -5711,6 +5730,34 @@ TEST_F(HTTPSEVCRLSetTest, MissingCRLSetAndInvalidOCSP) { static_cast<bool>(cert_status & CERT_STATUS_REV_CHECKING_ENABLED)); } +TEST_F(HTTPSEVCRLSetTest, MissingCRLSetAndRevokedOCSP) { + if (!SystemSupportsOCSP()) { + LOG(WARNING) << "Skipping test because system doesn't support OCSP"; + return; + } + + SpawnedTestServer::SSLOptions ssl_options( + SpawnedTestServer::SSLOptions::CERT_AUTO); + ssl_options.ocsp_status = SpawnedTestServer::SSLOptions::OCSP_REVOKED; + SSLConfigService::SetCRLSet(scoped_refptr<CRLSet>()); + + CertStatus cert_status; + DoConnection(ssl_options, &cert_status); + + // Currently only works for Windows. When using NSS or OS X, it's not + // possible to determine whether the check failed because of actual + // revocation or because there was an OCSP failure. +#if defined(OS_WIN) + EXPECT_EQ(CERT_STATUS_REVOKED, cert_status & CERT_STATUS_ALL_ERRORS); +#else + EXPECT_EQ(0u, cert_status & CERT_STATUS_ALL_ERRORS); +#endif + + EXPECT_FALSE(cert_status & CERT_STATUS_IS_EV); + EXPECT_EQ(SystemUsesChromiumEVMetadata(), + static_cast<bool>(cert_status & CERT_STATUS_REV_CHECKING_ENABLED)); +} + TEST_F(HTTPSEVCRLSetTest, MissingCRLSetAndGoodOCSP) { if (!SystemSupportsOCSP()) { LOG(WARNING) << "Skipping test because system doesn't support OCSP"; @@ -5920,7 +5967,7 @@ TEST_F(URLRequestTestFTP, UnsafePort) { r.Start(); EXPECT_TRUE(r.is_pending()); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_FALSE(r.is_pending()); EXPECT_EQ(URLRequestStatus::FAILED, r.status().status()); @@ -5938,7 +5985,7 @@ TEST_F(URLRequestTestFTP, DISABLED_FTPDirectoryListing) { r.Start(); EXPECT_TRUE(r.is_pending()); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); EXPECT_FALSE(r.is_pending()); EXPECT_EQ(1, d.response_started_count()); @@ -5964,7 +6011,7 @@ TEST_F(URLRequestTestFTP, DISABLED_FTPGetTestAnonymous) { r.Start(); EXPECT_TRUE(r.is_pending()); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); int64 file_size = 0; file_util::GetFileSize(app_path, &file_size); @@ -5996,7 +6043,7 @@ TEST_F(URLRequestTestFTP, DISABLED_FTPGetTest) { r.Start(); EXPECT_TRUE(r.is_pending()); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); int64 file_size = 0; file_util::GetFileSize(app_path, &file_size); @@ -6034,7 +6081,7 @@ TEST_F(URLRequestTestFTP, DISABLED_FTPCheckWrongPassword) { r.Start(); EXPECT_TRUE(r.is_pending()); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); int64 file_size = 0; file_util::GetFileSize(app_path, &file_size); @@ -6067,7 +6114,7 @@ TEST_F(URLRequestTestFTP, DISABLED_FTPCheckWrongPasswordRestart) { r.Start(); EXPECT_TRUE(r.is_pending()); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); int64 file_size = 0; file_util::GetFileSize(app_path, &file_size); @@ -6097,7 +6144,7 @@ TEST_F(URLRequestTestFTP, DISABLED_FTPCheckWrongUser) { r.Start(); EXPECT_TRUE(r.is_pending()); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); int64 file_size = 0; file_util::GetFileSize(app_path, &file_size); @@ -6130,7 +6177,7 @@ TEST_F(URLRequestTestFTP, DISABLED_FTPCheckWrongUserRestart) { r.Start(); EXPECT_TRUE(r.is_pending()); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); int64 file_size = 0; file_util::GetFileSize(app_path, &file_size); @@ -6162,7 +6209,7 @@ TEST_F(URLRequestTestFTP, DISABLED_FTPCacheURLCredentials) { r.Start(); EXPECT_TRUE(r.is_pending()); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); int64 file_size = 0; file_util::GetFileSize(app_path, &file_size); @@ -6180,7 +6227,7 @@ TEST_F(URLRequestTestFTP, DISABLED_FTPCacheURLCredentials) { r.Start(); EXPECT_TRUE(r.is_pending()); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); int64 file_size = 0; file_util::GetFileSize(app_path, &file_size); @@ -6214,7 +6261,7 @@ TEST_F(URLRequestTestFTP, DISABLED_FTPCacheLoginBoxCredentials) { r.Start(); EXPECT_TRUE(r.is_pending()); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); int64 file_size = 0; file_util::GetFileSize(app_path, &file_size); @@ -6235,7 +6282,7 @@ TEST_F(URLRequestTestFTP, DISABLED_FTPCacheLoginBoxCredentials) { r.Start(); EXPECT_TRUE(r.is_pending()); - base::MessageLoop::current()->Run(); + base::RunLoop().Run(); int64 file_size = 0; file_util::GetFileSize(app_path, &file_size); diff --git a/chromium/net/url_request/view_cache_helper_unittest.cc b/chromium/net/url_request/view_cache_helper_unittest.cc index db26f70ca05..0fdd1938a16 100644 --- a/chromium/net/url_request/view_cache_helper_unittest.cc +++ b/chromium/net/url_request/view_cache_helper_unittest.cc @@ -34,7 +34,8 @@ TestURLRequestContext::TestURLRequestContext() set_http_transaction_factory(&cache_); } -void WriteHeaders(disk_cache::Entry* entry, int flags, const std::string data) { +void WriteHeaders(disk_cache::Entry* entry, int flags, + const std::string& data) { if (data.empty()) return; @@ -53,7 +54,7 @@ void WriteHeaders(disk_cache::Entry* entry, int flags, const std::string data) { ASSERT_EQ(len, cb.GetResult(rv)); } -void WriteData(disk_cache::Entry* entry, int index, const std::string data) { +void WriteData(disk_cache::Entry* entry, int index, const std::string& data) { if (data.empty()) return; @@ -66,9 +67,9 @@ void WriteData(disk_cache::Entry* entry, int index, const std::string data) { ASSERT_EQ(len, cb.GetResult(rv)); } -void WriteToEntry(disk_cache::Backend* cache, const std::string key, - const std::string data0, const std::string data1, - const std::string data2) { +void WriteToEntry(disk_cache::Backend* cache, const std::string& key, + const std::string& data0, const std::string& data1, + const std::string& data2) { net::TestCompletionCallback cb; disk_cache::Entry* entry; int rv = cache->CreateEntry(key, &entry, cb.callback()); diff --git a/chromium/net/websockets/README b/chromium/net/websockets/README index 558a4511968..1d1e1c3538f 100644 --- a/chromium/net/websockets/README +++ b/chromium/net/websockets/README @@ -12,42 +12,57 @@ https://docs.google.com/a/google.com/document/d/1_R6YjCIrm4kikJ3YeapcOU2Keqr3lVU websocket_handshake_handler.cc websocket_handshake_handler.h -websocket_handshake_handler_unittest.cc -websocket_handshake_handler_spdy2_unittest.cc -websocket_handshake_handler_spdy3_unittest.cc +websocket_handshake_handler_test.cc +websocket_handshake_handler_spdy_test.cc websocket_job.cc websocket_job.h -websocket_job_unittest.cc +websocket_job_test.cc websocket_net_log_params.cc websocket_net_log_params.h -websocket_net_log_params_unittest.cc +websocket_net_log_params_test.cc websocket_throttle.cc websocket_throttle.h -websocket_throttle_unittest.cc +websocket_throttle_test.cc The following files are part of the new implementation. The new implementation performs framing and implements protocol semantics in the browser process, and presents a high-level interface to the renderer process similar to a multiplexing proxy. This is not yet used in any stable Chromium version. +websocket_basic_stream.cc +websocket_basic_stream.h +websocket_basic_stream_test.cc websocket_channel.cc websocket_channel.h websocket_channel_test.cc +websocket_deflater.h +websocket_deflater.cc +websocket_deflater_test.cc websocket_errors.cc websocket_errors.h -websocket_errors_unittest.cc +websocket_extension.cc +websocket_extension.h +websocket_extension_parser.cc +websocket_extension_parser.h +websocket_extension_parser_test.cc +websocket_errors_test.cc websocket_event_interface.h websocket_frame.cc websocket_frame.h websocket_frame_parser.cc websocket_frame_parser.h -websocket_frame_parser_unittest.cc -websocket_frame_unittest.cc +websocket_frame_parser_test.cc +websocket_frame_test.cc websocket_mux.h websocket_stream_base.h websocket_stream.cc websocket_stream.h +These files are shared between the old and new implementations. + +websocket_handshake_constants.cc +websocket_handshake_constants.h + A pre-submit check helps us keep this README file up-to-date: PRESUBMIT.py diff --git a/chromium/net/websockets/websocket_basic_stream.cc b/chromium/net/websockets/websocket_basic_stream.cc new file mode 100644 index 00000000000..5b02b18cb61 --- /dev/null +++ b/chromium/net/websockets/websocket_basic_stream.cc @@ -0,0 +1,259 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/websockets/websocket_basic_stream.h" + +#include <algorithm> +#include <limits> +#include <string> +#include <vector> + +#include "base/basictypes.h" +#include "base/bind.h" +#include "base/logging.h" +#include "net/base/io_buffer.h" +#include "net/base/net_errors.h" +#include "net/socket/client_socket_handle.h" +#include "net/websockets/websocket_errors.h" +#include "net/websockets/websocket_frame.h" +#include "net/websockets/websocket_frame_parser.h" + +namespace net { + +namespace { + +// The number of bytes to attempt to read at a time. +// TODO(ricea): See if there is a better number or algorithm to fulfill our +// requirements: +// 1. We would like to use minimal memory on low-bandwidth or idle connections +// 2. We would like to read as close to line speed as possible on +// high-bandwidth connections +// 3. We can't afford to cause jank on the IO thread by copying large buffers +// around +// 4. We would like to hit any sweet-spots that might exist in terms of network +// packet sizes / encryption block sizes / IPC alignment issues, etc. +const int kReadBufferSize = 32 * 1024; + +} // namespace + +WebSocketBasicStream::WebSocketBasicStream( + scoped_ptr<ClientSocketHandle> connection) + : read_buffer_(new IOBufferWithSize(kReadBufferSize)), + connection_(connection.Pass()), + generate_websocket_masking_key_(&GenerateWebSocketMaskingKey) { + DCHECK(connection_->is_initialized()); +} + +WebSocketBasicStream::~WebSocketBasicStream() { Close(); } + +int WebSocketBasicStream::ReadFrames( + ScopedVector<WebSocketFrameChunk>* frame_chunks, + const CompletionCallback& callback) { + DCHECK(frame_chunks->empty()); + // If there is data left over after parsing the HTTP headers, attempt to parse + // it as WebSocket frames. + if (http_read_buffer_) { + DCHECK_GE(http_read_buffer_->offset(), 0); + // We cannot simply copy the data into read_buffer_, as it might be too + // large. + scoped_refptr<GrowableIOBuffer> buffered_data; + buffered_data.swap(http_read_buffer_); + DCHECK(http_read_buffer_.get() == NULL); + if (!parser_.Decode(buffered_data->StartOfBuffer(), + buffered_data->offset(), + frame_chunks)) + return WebSocketErrorToNetError(parser_.websocket_error()); + if (!frame_chunks->empty()) + return OK; + } + + // Run until socket stops giving us data or we get some chunks. + while (true) { + // base::Unretained(this) here is safe because net::Socket guarantees not to + // call any callbacks after Disconnect(), which we call from the + // destructor. The caller of ReadFrames() is required to keep |frame_chunks| + // valid. + int result = connection_->socket() + ->Read(read_buffer_.get(), + read_buffer_->size(), + base::Bind(&WebSocketBasicStream::OnReadComplete, + base::Unretained(this), + base::Unretained(frame_chunks), + callback)); + if (result == ERR_IO_PENDING) + return result; + result = HandleReadResult(result, frame_chunks); + if (result != ERR_IO_PENDING) + return result; + } +} + +int WebSocketBasicStream::WriteFrames( + ScopedVector<WebSocketFrameChunk>* frame_chunks, + const CompletionCallback& callback) { + // This function always concatenates all frames into a single buffer. + // TODO(ricea): Investigate whether it would be better in some cases to + // perform multiple writes with smaller buffers. + // + // First calculate the size of the buffer we need to allocate. + typedef ScopedVector<WebSocketFrameChunk>::const_iterator Iterator; + const int kMaximumTotalSize = std::numeric_limits<int>::max(); + int total_size = 0; + for (Iterator it = frame_chunks->begin(); it != frame_chunks->end(); ++it) { + WebSocketFrameChunk* chunk = *it; + DCHECK(chunk->header) + << "Only complete frames are supported by WebSocketBasicStream"; + DCHECK(chunk->final_chunk) + << "Only complete frames are supported by WebSocketBasicStream"; + // Force the masked bit on. + chunk->header->masked = true; + // We enforce flow control so the renderer should never be able to force us + // to cache anywhere near 2GB of frames. + int chunk_size = + chunk->data->size() + GetWebSocketFrameHeaderSize(*(chunk->header)); + CHECK_GE(kMaximumTotalSize - total_size, chunk_size) + << "Aborting to prevent overflow"; + total_size += chunk_size; + } + scoped_refptr<IOBufferWithSize> combined_buffer( + new IOBufferWithSize(total_size)); + char* dest = combined_buffer->data(); + int remaining_size = total_size; + for (Iterator it = frame_chunks->begin(); it != frame_chunks->end(); ++it) { + WebSocketFrameChunk* chunk = *it; + WebSocketMaskingKey mask = generate_websocket_masking_key_(); + int result = WriteWebSocketFrameHeader( + *(chunk->header), &mask, dest, remaining_size); + DCHECK(result != ERR_INVALID_ARGUMENT) + << "WriteWebSocketFrameHeader() says that " << remaining_size + << " is not enough to write the header in. This should not happen."; + CHECK_GE(result, 0) << "Potentially security-critical check failed"; + dest += result; + remaining_size -= result; + + const char* const frame_data = chunk->data->data(); + const int frame_size = chunk->data->size(); + CHECK_GE(remaining_size, frame_size); + std::copy(frame_data, frame_data + frame_size, dest); + MaskWebSocketFramePayload(mask, 0, dest, frame_size); + dest += frame_size; + remaining_size -= frame_size; + } + DCHECK_EQ(0, remaining_size) << "Buffer size calculation was wrong; " + << remaining_size << " bytes left over."; + scoped_refptr<DrainableIOBuffer> drainable_buffer( + new DrainableIOBuffer(combined_buffer, total_size)); + return WriteEverything(drainable_buffer, callback); +} + +void WebSocketBasicStream::Close() { connection_->socket()->Disconnect(); } + +std::string WebSocketBasicStream::GetSubProtocol() const { + return sub_protocol_; +} + +std::string WebSocketBasicStream::GetExtensions() const { return extensions_; } + +int WebSocketBasicStream::SendHandshakeRequest( + const GURL& url, + const HttpRequestHeaders& headers, + HttpResponseInfo* response_info, + const CompletionCallback& callback) { + // TODO(ricea): Implement handshake-related functionality. + NOTREACHED(); + return ERR_NOT_IMPLEMENTED; +} + +int WebSocketBasicStream::ReadHandshakeResponse( + const CompletionCallback& callback) { + NOTREACHED(); + return ERR_NOT_IMPLEMENTED; +} + +/*static*/ +scoped_ptr<WebSocketBasicStream> +WebSocketBasicStream::CreateWebSocketBasicStreamForTesting( + scoped_ptr<ClientSocketHandle> connection, + const scoped_refptr<GrowableIOBuffer>& http_read_buffer, + const std::string& sub_protocol, + const std::string& extensions, + WebSocketMaskingKeyGeneratorFunction key_generator_function) { + scoped_ptr<WebSocketBasicStream> stream( + new WebSocketBasicStream(connection.Pass())); + if (http_read_buffer) { + stream->http_read_buffer_ = http_read_buffer; + } + stream->sub_protocol_ = sub_protocol; + stream->extensions_ = extensions; + stream->generate_websocket_masking_key_ = key_generator_function; + return stream.Pass(); +} + +int WebSocketBasicStream::WriteEverything( + const scoped_refptr<DrainableIOBuffer>& buffer, + const CompletionCallback& callback) { + while (buffer->BytesRemaining() > 0) { + // The use of base::Unretained() here is safe because on destruction we + // disconnect the socket, preventing any further callbacks. + int result = connection_->socket() + ->Write(buffer.get(), + buffer->BytesRemaining(), + base::Bind(&WebSocketBasicStream::OnWriteComplete, + base::Unretained(this), + buffer, + callback)); + if (result > 0) { + buffer->DidConsume(result); + } else { + return result; + } + } + return OK; +} + +void WebSocketBasicStream::OnWriteComplete( + const scoped_refptr<DrainableIOBuffer>& buffer, + const CompletionCallback& callback, + int result) { + if (result < 0) { + DCHECK(result != ERR_IO_PENDING); + callback.Run(result); + return; + } + + DCHECK(result != 0); + buffer->DidConsume(result); + result = WriteEverything(buffer, callback); + if (result != ERR_IO_PENDING) + callback.Run(result); +} + +int WebSocketBasicStream::HandleReadResult( + int result, + ScopedVector<WebSocketFrameChunk>* frame_chunks) { + DCHECK_NE(ERR_IO_PENDING, result); + DCHECK(frame_chunks->empty()); + if (result < 0) + return result; + if (result == 0) + return ERR_CONNECTION_CLOSED; + if (!parser_.Decode(read_buffer_->data(), result, frame_chunks)) + return WebSocketErrorToNetError(parser_.websocket_error()); + if (!frame_chunks->empty()) + return OK; + return ERR_IO_PENDING; +} + +void WebSocketBasicStream::OnReadComplete( + ScopedVector<WebSocketFrameChunk>* frame_chunks, + const CompletionCallback& callback, + int result) { + result = HandleReadResult(result, frame_chunks); + if (result == ERR_IO_PENDING) + result = ReadFrames(frame_chunks, callback); + if (result != ERR_IO_PENDING) + callback.Run(result); +} + +} // namespace net diff --git a/chromium/net/websockets/websocket_basic_stream.h b/chromium/net/websockets/websocket_basic_stream.h new file mode 100644 index 00000000000..1231da8142b --- /dev/null +++ b/chromium/net/websockets/websocket_basic_stream.h @@ -0,0 +1,129 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_WEBSOCKETS_WEBSOCKET_BASIC_STREAM_H_ +#define NET_WEBSOCKETS_WEBSOCKET_BASIC_STREAM_H_ + +#include <string> + +#include "base/callback.h" +#include "base/memory/ref_counted.h" +#include "base/memory/scoped_ptr.h" +#include "base/memory/scoped_vector.h" +#include "net/websockets/websocket_frame_parser.h" +#include "net/websockets/websocket_stream.h" + +namespace net { + +class ClientSocketHandle; +class DrainableIOBuffer; +class GrowableIOBuffer; +class HttpRequestHeaders; +class HttpResponseInfo; +class IOBufferWithSize; +struct WebSocketFrameChunk; + +// Implementation of WebSocketStream for non-multiplexed ws:// connections (or +// the physical side of a multiplexed ws:// connection). +class NET_EXPORT_PRIVATE WebSocketBasicStream : public WebSocketStream { + public: + typedef WebSocketMaskingKey (*WebSocketMaskingKeyGeneratorFunction)(); + + // This class should not normally be constructed directly; see + // WebSocketStream::CreateAndConnectStream. + explicit WebSocketBasicStream(scoped_ptr<ClientSocketHandle> connection); + + // The destructor has to make sure the connection is closed when we finish so + // that it does not get returned to the pool. + virtual ~WebSocketBasicStream(); + + // WebSocketStream implementation. + virtual int ReadFrames(ScopedVector<WebSocketFrameChunk>* frame_chunks, + const CompletionCallback& callback) OVERRIDE; + + virtual int WriteFrames(ScopedVector<WebSocketFrameChunk>* frame_chunks, + const CompletionCallback& callback) OVERRIDE; + + virtual void Close() OVERRIDE; + + virtual std::string GetSubProtocol() const OVERRIDE; + + virtual std::string GetExtensions() const OVERRIDE; + + // Writes WebSocket handshake request HTTP-style to the connection. Adds + // "Sec-WebSocket-Key" header; this should not be included in |headers|. + virtual int SendHandshakeRequest(const GURL& url, + const HttpRequestHeaders& headers, + HttpResponseInfo* response_info, + const CompletionCallback& callback) OVERRIDE; + + virtual int ReadHandshakeResponse( + const CompletionCallback& callback) OVERRIDE; + + //////////////////////////////////////////////////////////////////////////// + // Methods for testing only. + + static scoped_ptr<WebSocketBasicStream> CreateWebSocketBasicStreamForTesting( + scoped_ptr<ClientSocketHandle> connection, + const scoped_refptr<GrowableIOBuffer>& http_read_buffer, + const std::string& sub_protocol, + const std::string& extensions, + WebSocketMaskingKeyGeneratorFunction key_generator_function); + + private: + // Returns OK or calls |callback| when the |buffer| is fully drained or + // something has failed. + int WriteEverything(const scoped_refptr<DrainableIOBuffer>& buffer, + const CompletionCallback& callback); + + // Wraps the |callback| to continue writing until everything has been written. + void OnWriteComplete(const scoped_refptr<DrainableIOBuffer>& buffer, + const CompletionCallback& callback, + int result); + + // Attempts to parse the output of a read as WebSocket frames. On success, + // returns OK and places the frame(s) in frame_chunks. + int HandleReadResult(int result, + ScopedVector<WebSocketFrameChunk>* frame_chunks); + + // Called when a read completes. Parses the result and (unless no complete + // header has been received) calls |callback|. + void OnReadComplete(ScopedVector<WebSocketFrameChunk>* frame_chunks, + const CompletionCallback& callback, + int result); + + // Storage for pending reads. All active WebSockets spend all the time with a + // call to ReadFrames() pending, so there is no benefit in trying to share + // this between sockets. + scoped_refptr<IOBufferWithSize> read_buffer_; + + // The connection, wrapped in a ClientSocketHandle so that we can prevent it + // from being returned to the pool. + scoped_ptr<ClientSocketHandle> connection_; + + // Only used during handshake. Some data may be left in this buffer after the + // handshake, in which case it will be picked up during the first call to + // ReadFrames(). The type is GrowableIOBuffer for compatibility with + // net::HttpStreamParser, which is used to parse the handshake. + scoped_refptr<GrowableIOBuffer> http_read_buffer_; + + // This keeps the current parse state (including any incomplete headers) and + // parses frames. + WebSocketFrameParser parser_; + + // The negotated sub-protocol, or empty for none. + std::string sub_protocol_; + + // The extensions negotiated with the remote server. + std::string extensions_; + + // This can be overridden in tests to make the output deterministic. We don't + // use a Callback here because a function pointer is faster and good enough + // for our purposes. + WebSocketMaskingKeyGeneratorFunction generate_websocket_masking_key_; +}; + +} // namespace net + +#endif // NET_WEBSOCKETS_WEBSOCKET_BASIC_STREAM_H_ diff --git a/chromium/net/websockets/websocket_basic_stream_test.cc b/chromium/net/websockets/websocket_basic_stream_test.cc new file mode 100644 index 00000000000..ec2e51a6d0f --- /dev/null +++ b/chromium/net/websockets/websocket_basic_stream_test.cc @@ -0,0 +1,513 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. +// +// Tests for WebSocketBasicStream. Note that we do not attempt to verify that +// frame parsing itself functions correctly, as that is covered by the +// WebSocketFrameParser tests. + +#include "net/websockets/websocket_basic_stream.h" + +#include "base/basictypes.h" +#include "base/port.h" +#include "net/base/capturing_net_log.h" +#include "net/base/test_completion_callback.h" +#include "net/socket/socket_test_util.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace net { +namespace { + +// TODO(ricea): Add tests for +// - Empty frames (data & control) +// - Non-NULL masking key +// - A frame larger than kReadBufferSize; + +const char kSampleFrame[] = "\x81\x06Sample"; +const size_t kSampleFrameSize = arraysize(kSampleFrame) - 1; +const char kPartialLargeFrame[] = + "\x81\x7F\x00\x00\x00\x00\x7F\xFF\xFF\xFF" + "chromiunum ad pasco per loca insanis pullum manducat frumenti"; +const size_t kPartialLargeFrameSize = arraysize(kPartialLargeFrame) - 1; +const size_t kLargeFrameHeaderSize = 10; +const size_t kLargeFrameDeclaredPayloadSize = 0x7FFFFFFF; +const char kMultipleFrames[] = "\x81\x01X\x81\x01Y\x81\x01Z"; +const size_t kMultipleFramesSize = arraysize(kMultipleFrames) - 1; +// This frame encodes a payload length of 7 in two bytes, which is always +// invalid. +const char kInvalidFrame[] = "\x81\x7E\x00\x07Invalid"; +const size_t kInvalidFrameSize = arraysize(kInvalidFrame) - 1; +const char kWriteFrame[] = "\x81\x85\x00\x00\x00\x00Write"; +const size_t kWriteFrameSize = arraysize(kWriteFrame) - 1; +const WebSocketMaskingKey kNulMaskingKey = {{'\0', '\0', '\0', '\0'}}; + +// Generates a ScopedVector<WebSocketFrameChunk> which will have a wire format +// matching kWriteFrame. +ScopedVector<WebSocketFrameChunk> GenerateWriteFrame() { + scoped_ptr<WebSocketFrameChunk> chunk(new WebSocketFrameChunk); + const size_t payload_size = + kWriteFrameSize - (WebSocketFrameHeader::kBaseHeaderSize + + WebSocketFrameHeader::kMaskingKeyLength); + chunk->data = new IOBufferWithSize(payload_size); + memcpy(chunk->data->data(), + kWriteFrame + kWriteFrameSize - payload_size, + payload_size); + chunk->final_chunk = true; + scoped_ptr<WebSocketFrameHeader> header( + new WebSocketFrameHeader(WebSocketFrameHeader::kOpCodeText)); + header->final = true; + header->masked = true; + header->payload_length = payload_size; + chunk->header = header.Pass(); + ScopedVector<WebSocketFrameChunk> chunks; + chunks.push_back(chunk.release()); + return chunks.Pass(); +} + +// A masking key generator function which generates the identity mask, +// ie. "\0\0\0\0". +WebSocketMaskingKey GenerateNulMaskingKey() { return kNulMaskingKey; } + +// Base class for WebSocketBasicStream test fixtures. +class WebSocketBasicStreamTest : public ::testing::Test { + protected: + scoped_ptr<WebSocketBasicStream> stream_; + CapturingNetLog net_log_; +}; + +// A fixture for tests which only perform normal socket operations. +class WebSocketBasicStreamSocketTest : public WebSocketBasicStreamTest { + protected: + WebSocketBasicStreamSocketTest() + : histograms_("a"), pool_(1, 1, &histograms_, &factory_) {} + + virtual ~WebSocketBasicStreamSocketTest() { + // stream_ has a reference to socket_data_ (via MockTCPClientSocket) and so + // should be destroyed first. + stream_.reset(); + } + + scoped_ptr<ClientSocketHandle> MakeTransportSocket(MockRead reads[], + size_t reads_count, + MockWrite writes[], + size_t writes_count) { + socket_data_.reset( + new StaticSocketDataProvider(reads, reads_count, writes, writes_count)); + socket_data_->set_connect_data(MockConnect(SYNCHRONOUS, OK)); + factory_.AddSocketDataProvider(socket_data_.get()); + + scoped_ptr<ClientSocketHandle> transport_socket(new ClientSocketHandle); + scoped_refptr<MockTransportSocketParams> params; + transport_socket->Init("a", + params, + MEDIUM, + CompletionCallback(), + &pool_, + bound_net_log_.bound()); + return transport_socket.Pass(); + } + + void SetHttpReadBuffer(const char* data, size_t size) { + http_read_buffer_ = new GrowableIOBuffer; + http_read_buffer_->SetCapacity(size); + memcpy(http_read_buffer_->data(), data, size); + http_read_buffer_->set_offset(size); + } + + void CreateStream(MockRead reads[], + size_t reads_count, + MockWrite writes[], + size_t writes_count) { + stream_ = WebSocketBasicStream::CreateWebSocketBasicStreamForTesting( + MakeTransportSocket(reads, reads_count, writes, writes_count), + http_read_buffer_, + sub_protocol_, + extensions_, + &GenerateNulMaskingKey); + } + + template <size_t N> + void CreateReadOnly(MockRead (&reads)[N]) { + CreateStream(reads, N, NULL, 0); + } + + template <size_t N> + void CreateWriteOnly(MockWrite (&writes)[N]) { + CreateStream(NULL, 0, writes, N); + } + + void CreateNullStream() { CreateStream(NULL, 0, NULL, 0); } + + scoped_ptr<SocketDataProvider> socket_data_; + MockClientSocketFactory factory_; + ClientSocketPoolHistograms histograms_; + MockTransportClientSocketPool pool_; + CapturingBoundNetLog(bound_net_log_); + ScopedVector<WebSocketFrameChunk> frame_chunks_; + TestCompletionCallback cb_; + scoped_refptr<GrowableIOBuffer> http_read_buffer_; + std::string sub_protocol_; + std::string extensions_; +}; + +TEST_F(WebSocketBasicStreamSocketTest, ConstructionWorks) { + CreateNullStream(); +} + +TEST_F(WebSocketBasicStreamSocketTest, SyncReadWorks) { + MockRead reads[] = {MockRead(SYNCHRONOUS, kSampleFrame, kSampleFrameSize)}; + CreateReadOnly(reads); + int result = stream_->ReadFrames(&frame_chunks_, cb_.callback()); + EXPECT_EQ(OK, result); + ASSERT_EQ(1U, frame_chunks_.size()); + ASSERT_TRUE(frame_chunks_[0]->header); + EXPECT_EQ(GG_UINT64_C(6), frame_chunks_[0]->header->payload_length); + EXPECT_TRUE(frame_chunks_[0]->header->final); + EXPECT_TRUE(frame_chunks_[0]->final_chunk); +} + +TEST_F(WebSocketBasicStreamSocketTest, AsyncReadWorks) { + MockRead reads[] = {MockRead(ASYNC, kSampleFrame, kSampleFrameSize)}; + CreateReadOnly(reads); + int result = stream_->ReadFrames(&frame_chunks_, cb_.callback()); + ASSERT_EQ(ERR_IO_PENDING, result); + EXPECT_EQ(OK, cb_.WaitForResult()); + ASSERT_EQ(1U, frame_chunks_.size()); + ASSERT_TRUE(frame_chunks_[0]->header); + EXPECT_EQ(GG_UINT64_C(6), frame_chunks_[0]->header->payload_length); + // Don't repeat all the tests from SyncReadWorks; just enough to be sure the + // frame was really read. +} + +// ReadFrames will not return a frame whose header has not been wholly received. +TEST_F(WebSocketBasicStreamSocketTest, HeaderFragmentedSync) { + MockRead reads[] = { + MockRead(SYNCHRONOUS, kSampleFrame, 1), + MockRead(SYNCHRONOUS, kSampleFrame + 1, kSampleFrameSize - 1)}; + CreateReadOnly(reads); + int result = stream_->ReadFrames(&frame_chunks_, cb_.callback()); + ASSERT_EQ(OK, result); + ASSERT_EQ(1U, frame_chunks_.size()); + ASSERT_TRUE(frame_chunks_[0]->header); + EXPECT_EQ(GG_UINT64_C(6), frame_chunks_[0]->header->payload_length); +} + +// The same behaviour applies to asynchronous reads. +TEST_F(WebSocketBasicStreamSocketTest, HeaderFragmentedAsync) { + MockRead reads[] = {MockRead(ASYNC, kSampleFrame, 1), + MockRead(ASYNC, kSampleFrame + 1, kSampleFrameSize - 1)}; + CreateReadOnly(reads); + int result = stream_->ReadFrames(&frame_chunks_, cb_.callback()); + ASSERT_EQ(ERR_IO_PENDING, result); + EXPECT_EQ(OK, cb_.WaitForResult()); + ASSERT_EQ(1U, frame_chunks_.size()); + ASSERT_TRUE(frame_chunks_[0]->header); + EXPECT_EQ(GG_UINT64_C(6), frame_chunks_[0]->header->payload_length); +} + +// If it receives an incomplete header in a synchronous call, then has to wait +// for the rest of the frame, ReadFrames will return ERR_IO_PENDING. +TEST_F(WebSocketBasicStreamSocketTest, HeaderFragmentedSyncAsync) { + MockRead reads[] = {MockRead(SYNCHRONOUS, kSampleFrame, 1), + MockRead(ASYNC, kSampleFrame + 1, kSampleFrameSize - 1)}; + CreateReadOnly(reads); + int result = stream_->ReadFrames(&frame_chunks_, cb_.callback()); + ASSERT_EQ(ERR_IO_PENDING, result); + EXPECT_EQ(OK, cb_.WaitForResult()); + ASSERT_EQ(1U, frame_chunks_.size()); + ASSERT_TRUE(frame_chunks_[0]->header); + EXPECT_EQ(GG_UINT64_C(6), frame_chunks_[0]->header->payload_length); +} + +// An extended header should also return ERR_IO_PENDING if it is not completely +// received. +TEST_F(WebSocketBasicStreamSocketTest, FragmentedLargeHeader) { + MockRead reads[] = { + MockRead(SYNCHRONOUS, kPartialLargeFrame, kLargeFrameHeaderSize - 1), + MockRead(SYNCHRONOUS, ERR_IO_PENDING)}; + CreateReadOnly(reads); + EXPECT_EQ(ERR_IO_PENDING, + stream_->ReadFrames(&frame_chunks_, cb_.callback())); +} + +// A frame that does not arrive in a single read should arrive in chunks. +TEST_F(WebSocketBasicStreamSocketTest, LargeFrameFirstChunk) { + MockRead reads[] = { + MockRead(SYNCHRONOUS, kPartialLargeFrame, kPartialLargeFrameSize)}; + CreateReadOnly(reads); + EXPECT_EQ(OK, stream_->ReadFrames(&frame_chunks_, cb_.callback())); + ASSERT_EQ(1U, frame_chunks_.size()); + ASSERT_TRUE(frame_chunks_[0]->header); + EXPECT_EQ(kLargeFrameDeclaredPayloadSize, + frame_chunks_[0]->header->payload_length); + EXPECT_TRUE(frame_chunks_[0]->header->final); + EXPECT_FALSE(frame_chunks_[0]->final_chunk); + EXPECT_EQ(kPartialLargeFrameSize - kLargeFrameHeaderSize, + static_cast<size_t>(frame_chunks_[0]->data->size())); +} + +// If only the header arrives, we should get a zero-byte chunk. +TEST_F(WebSocketBasicStreamSocketTest, HeaderOnlyChunk) { + MockRead reads[] = { + MockRead(SYNCHRONOUS, kPartialLargeFrame, kLargeFrameHeaderSize)}; + CreateReadOnly(reads); + EXPECT_EQ(OK, stream_->ReadFrames(&frame_chunks_, cb_.callback())); + ASSERT_EQ(1U, frame_chunks_.size()); + EXPECT_FALSE(frame_chunks_[0]->final_chunk); + EXPECT_TRUE(frame_chunks_[0]->data.get() == NULL); +} + +// The second and subsequent chunks of a frame have no header. +TEST_F(WebSocketBasicStreamSocketTest, LargeFrameTwoChunks) { + static const size_t kChunkSize = 16; + MockRead reads[] = { + MockRead(ASYNC, kPartialLargeFrame, kChunkSize), + MockRead(ASYNC, kPartialLargeFrame + kChunkSize, kChunkSize)}; + CreateReadOnly(reads); + TestCompletionCallback cb[2]; + + ASSERT_EQ(ERR_IO_PENDING, + stream_->ReadFrames(&frame_chunks_, cb[0].callback())); + EXPECT_EQ(OK, cb[0].WaitForResult()); + ASSERT_EQ(1U, frame_chunks_.size()); + ASSERT_TRUE(frame_chunks_[0]->header); + + frame_chunks_.clear(); + ASSERT_EQ(ERR_IO_PENDING, + stream_->ReadFrames(&frame_chunks_, cb[1].callback())); + EXPECT_EQ(OK, cb[1].WaitForResult()); + ASSERT_EQ(1U, frame_chunks_.size()); + ASSERT_FALSE(frame_chunks_[0]->header); +} + +// Only the final chunk of a frame has final_chunk set. +TEST_F(WebSocketBasicStreamSocketTest, OnlyFinalChunkIsFinal) { + static const size_t kFirstChunkSize = 4; + MockRead reads[] = {MockRead(ASYNC, kSampleFrame, kFirstChunkSize), + MockRead(ASYNC, + kSampleFrame + kFirstChunkSize, + kSampleFrameSize - kFirstChunkSize)}; + CreateReadOnly(reads); + TestCompletionCallback cb[2]; + + ASSERT_EQ(ERR_IO_PENDING, + stream_->ReadFrames(&frame_chunks_, cb[0].callback())); + EXPECT_EQ(OK, cb[0].WaitForResult()); + ASSERT_EQ(1U, frame_chunks_.size()); + ASSERT_FALSE(frame_chunks_[0]->final_chunk); + + frame_chunks_.clear(); + ASSERT_EQ(ERR_IO_PENDING, + stream_->ReadFrames(&frame_chunks_, cb[1].callback())); + EXPECT_EQ(OK, cb[1].WaitForResult()); + ASSERT_EQ(1U, frame_chunks_.size()); + ASSERT_TRUE(frame_chunks_[0]->final_chunk); +} + +// Multiple frames that arrive together should be parsed correctly. +TEST_F(WebSocketBasicStreamSocketTest, ThreeFramesTogether) { + MockRead reads[] = { + MockRead(SYNCHRONOUS, kMultipleFrames, kMultipleFramesSize)}; + CreateReadOnly(reads); + + ASSERT_EQ(OK, stream_->ReadFrames(&frame_chunks_, cb_.callback())); + ASSERT_EQ(3U, frame_chunks_.size()); + EXPECT_TRUE(frame_chunks_[0]->final_chunk); + EXPECT_TRUE(frame_chunks_[1]->final_chunk); + EXPECT_TRUE(frame_chunks_[2]->final_chunk); +} + +// ERR_CONNECTION_CLOSED must be returned on close. +TEST_F(WebSocketBasicStreamSocketTest, SyncClose) { + MockRead reads[] = {MockRead(SYNCHRONOUS, "", 0)}; + CreateReadOnly(reads); + + EXPECT_EQ(ERR_CONNECTION_CLOSED, + stream_->ReadFrames(&frame_chunks_, cb_.callback())); +} + +TEST_F(WebSocketBasicStreamSocketTest, AsyncClose) { + MockRead reads[] = {MockRead(ASYNC, "", 0)}; + CreateReadOnly(reads); + + ASSERT_EQ(ERR_IO_PENDING, + stream_->ReadFrames(&frame_chunks_, cb_.callback())); + EXPECT_EQ(ERR_CONNECTION_CLOSED, cb_.WaitForResult()); +} + +// The result should be the same if the socket returns +// ERR_CONNECTION_CLOSED. This is not expected to happen on an established +// connection; a Read of size 0 is the expected behaviour. The key point of this +// test is to confirm that ReadFrames() behaviour is identical in both cases. +TEST_F(WebSocketBasicStreamSocketTest, SyncCloseWithErr) { + MockRead reads[] = {MockRead(SYNCHRONOUS, ERR_CONNECTION_CLOSED)}; + CreateReadOnly(reads); + + EXPECT_EQ(ERR_CONNECTION_CLOSED, + stream_->ReadFrames(&frame_chunks_, cb_.callback())); +} + +TEST_F(WebSocketBasicStreamSocketTest, AsyncCloseWithErr) { + MockRead reads[] = {MockRead(ASYNC, ERR_CONNECTION_CLOSED)}; + CreateReadOnly(reads); + + ASSERT_EQ(ERR_IO_PENDING, + stream_->ReadFrames(&frame_chunks_, cb_.callback())); + EXPECT_EQ(ERR_CONNECTION_CLOSED, cb_.WaitForResult()); +} + +TEST_F(WebSocketBasicStreamSocketTest, SyncErrorsPassedThrough) { + // ERR_INSUFFICIENT_RESOURCES here represents an arbitrary error that + // WebSocketBasicStream gives no special handling to. + MockRead reads[] = {MockRead(SYNCHRONOUS, ERR_INSUFFICIENT_RESOURCES)}; + CreateReadOnly(reads); + + EXPECT_EQ(ERR_INSUFFICIENT_RESOURCES, + stream_->ReadFrames(&frame_chunks_, cb_.callback())); +} + +TEST_F(WebSocketBasicStreamSocketTest, AsyncErrorsPassedThrough) { + MockRead reads[] = {MockRead(ASYNC, ERR_INSUFFICIENT_RESOURCES)}; + CreateReadOnly(reads); + + ASSERT_EQ(ERR_IO_PENDING, + stream_->ReadFrames(&frame_chunks_, cb_.callback())); + EXPECT_EQ(ERR_INSUFFICIENT_RESOURCES, cb_.WaitForResult()); +} + +// If we get a frame followed by a close, we should receive them separately. +TEST_F(WebSocketBasicStreamSocketTest, CloseAfterFrame) { + MockRead reads[] = {MockRead(SYNCHRONOUS, kSampleFrame, kSampleFrameSize), + MockRead(SYNCHRONOUS, "", 0)}; + CreateReadOnly(reads); + + EXPECT_EQ(OK, stream_->ReadFrames(&frame_chunks_, cb_.callback())); + EXPECT_EQ(1U, frame_chunks_.size()); + frame_chunks_.clear(); + EXPECT_EQ(ERR_CONNECTION_CLOSED, + stream_->ReadFrames(&frame_chunks_, cb_.callback())); +} + +// Synchronous close after an async frame header is handled by a different code +// path. +TEST_F(WebSocketBasicStreamSocketTest, AsyncCloseAfterIncompleteHeader) { + MockRead reads[] = {MockRead(ASYNC, kSampleFrame, 1U), + MockRead(SYNCHRONOUS, "", 0)}; + CreateReadOnly(reads); + + ASSERT_EQ(ERR_IO_PENDING, + stream_->ReadFrames(&frame_chunks_, cb_.callback())); + ASSERT_EQ(ERR_CONNECTION_CLOSED, cb_.WaitForResult()); +} + +// When Stream::Read returns ERR_CONNECTION_CLOSED we get the same result via a +// slightly different code path. +TEST_F(WebSocketBasicStreamSocketTest, AsyncErrCloseAfterIncompleteHeader) { + MockRead reads[] = {MockRead(ASYNC, kSampleFrame, 1U), + MockRead(SYNCHRONOUS, ERR_CONNECTION_CLOSED)}; + CreateReadOnly(reads); + + ASSERT_EQ(ERR_IO_PENDING, + stream_->ReadFrames(&frame_chunks_, cb_.callback())); + ASSERT_EQ(ERR_CONNECTION_CLOSED, cb_.WaitForResult()); +} + +// If there was a frame read at the same time as the response headers (and the +// handshake succeeded), then we should parse it. +TEST_F(WebSocketBasicStreamSocketTest, HttpReadBufferIsUsed) { + SetHttpReadBuffer(kSampleFrame, kSampleFrameSize); + CreateNullStream(); + + EXPECT_EQ(OK, stream_->ReadFrames(&frame_chunks_, cb_.callback())); + ASSERT_EQ(1U, frame_chunks_.size()); + ASSERT_TRUE(frame_chunks_[0]->data); + EXPECT_EQ(6, frame_chunks_[0]->data->size()); +} + +// Check that a frame whose header partially arrived at the end of the response +// headers works correctly. +TEST_F(WebSocketBasicStreamSocketTest, PartialFrameHeaderInHttpResponse) { + SetHttpReadBuffer(kSampleFrame, 1); + MockRead reads[] = {MockRead(ASYNC, kSampleFrame + 1, kSampleFrameSize - 1)}; + CreateReadOnly(reads); + + ASSERT_EQ(ERR_IO_PENDING, + stream_->ReadFrames(&frame_chunks_, cb_.callback())); + EXPECT_EQ(OK, cb_.WaitForResult()); + ASSERT_EQ(1U, frame_chunks_.size()); + ASSERT_TRUE(frame_chunks_[0]->data); + EXPECT_EQ(6, frame_chunks_[0]->data->size()); + ASSERT_TRUE(frame_chunks_[0]->header); + EXPECT_EQ(WebSocketFrameHeader::kOpCodeText, + frame_chunks_[0]->header->opcode); +} + +// Check that an invalid frame results in an error. +TEST_F(WebSocketBasicStreamSocketTest, SyncInvalidFrame) { + MockRead reads[] = {MockRead(SYNCHRONOUS, kInvalidFrame, kInvalidFrameSize)}; + CreateReadOnly(reads); + + EXPECT_EQ(ERR_WS_PROTOCOL_ERROR, + stream_->ReadFrames(&frame_chunks_, cb_.callback())); +} + +TEST_F(WebSocketBasicStreamSocketTest, AsyncInvalidFrame) { + MockRead reads[] = {MockRead(ASYNC, kInvalidFrame, kInvalidFrameSize)}; + CreateReadOnly(reads); + + ASSERT_EQ(ERR_IO_PENDING, + stream_->ReadFrames(&frame_chunks_, cb_.callback())); + EXPECT_EQ(ERR_WS_PROTOCOL_ERROR, cb_.WaitForResult()); +} + +// Check that writing a frame all at once works. +TEST_F(WebSocketBasicStreamSocketTest, WriteAtOnce) { + MockWrite writes[] = {MockWrite(SYNCHRONOUS, kWriteFrame, kWriteFrameSize)}; + CreateWriteOnly(writes); + frame_chunks_ = GenerateWriteFrame(); + + EXPECT_EQ(OK, stream_->WriteFrames(&frame_chunks_, cb_.callback())); +} + +// Check that completely async writing works. +TEST_F(WebSocketBasicStreamSocketTest, AsyncWriteAtOnce) { + MockWrite writes[] = {MockWrite(ASYNC, kWriteFrame, kWriteFrameSize)}; + CreateWriteOnly(writes); + frame_chunks_ = GenerateWriteFrame(); + + ASSERT_EQ(ERR_IO_PENDING, + stream_->WriteFrames(&frame_chunks_, cb_.callback())); + EXPECT_EQ(OK, cb_.WaitForResult()); +} + +// Check that writing a frame to an extremely full kernel buffer (so that it +// ends up being sent in bits) works. The WriteFrames() callback should not be +// called until all parts have been written. +TEST_F(WebSocketBasicStreamSocketTest, WriteInBits) { + MockWrite writes[] = {MockWrite(SYNCHRONOUS, kWriteFrame, 4), + MockWrite(ASYNC, kWriteFrame + 4, 4), + MockWrite(ASYNC, kWriteFrame + 8, kWriteFrameSize - 8)}; + CreateWriteOnly(writes); + frame_chunks_ = GenerateWriteFrame(); + + ASSERT_EQ(ERR_IO_PENDING, + stream_->WriteFrames(&frame_chunks_, cb_.callback())); + EXPECT_EQ(OK, cb_.WaitForResult()); +} + +TEST_F(WebSocketBasicStreamSocketTest, GetExtensionsWorks) { + extensions_ = "inflate-uuencode"; + CreateNullStream(); + + EXPECT_EQ("inflate-uuencode", stream_->GetExtensions()); +} + +TEST_F(WebSocketBasicStreamSocketTest, GetSubProtocolWorks) { + sub_protocol_ = "cyberchat"; + CreateNullStream(); + + EXPECT_EQ("cyberchat", stream_->GetSubProtocol()); +} + +} // namespace +} // namespace net diff --git a/chromium/net/websockets/websocket_channel.cc b/chromium/net/websockets/websocket_channel.cc index fd845f92c3b..3db457c08f4 100644 --- a/chromium/net/websockets/websocket_channel.cc +++ b/chromium/net/websockets/websocket_channel.cc @@ -31,19 +31,6 @@ const size_t kWebSocketCloseCodeLength = 2; // WebSocketFrameHeader::payload_length in websocket_frame.h. const uint64 kMaxControlFramePayload = 125; -// Concatenate the data from two IOBufferWithSize objects into a single one. -IOBufferWithSize* ConcatenateIOBuffers( - const scoped_refptr<IOBufferWithSize>& part1, - const scoped_refptr<IOBufferWithSize>& part2) { - int newsize = part1->size() + part2->size(); - IOBufferWithSize* newbuffer = new IOBufferWithSize(newsize); - std::copy(part1->data(), part1->data() + part1->size(), newbuffer->data()); - std::copy(part2->data(), - part2->data() + part2->size(), - newbuffer->data() + part1->size()); - return newbuffer; -} - } // namespace // A class to encapsulate a set of frames and information about the size of @@ -90,10 +77,10 @@ class WebSocketChannel::ConnectDelegate } private: - // A pointer to the WebSocketChannel that created us. We do not need to worry - // about this pointer being stale, because deleting WebSocketChannel cancels - // the connect process, deleting this object and preventing its callbacks from - // being called. + // A pointer to the WebSocketChannel that created this object. There is no + // danger of this pointer being stale, because deleting the WebSocketChannel + // cancels the connect process, deleting this object and preventing its + // callbacks from being called. WebSocketChannel* const creator_; DISALLOW_COPY_AND_ASSIGN(ConnectDelegate); @@ -129,8 +116,8 @@ void WebSocketChannel::SendAddChannelRequest( } bool WebSocketChannel::InClosingState() const { - // We intentionally do not support state RECV_CLOSED here, because it is only - // used in one code path and should not leak into the code in general. + // The state RECV_CLOSED is not supported here, because it is only used in one + // code path and should not leak into the code in general. DCHECK_NE(RECV_CLOSED, state_) << "InClosingState called with state_ == RECV_CLOSED"; return state_ == SEND_CLOSED || state_ == CLOSE_WAIT || state_ == CLOSED; @@ -172,9 +159,9 @@ void WebSocketChannel::SendFrame(bool fin, } current_send_quota_ -= data.size(); // TODO(ricea): If current_send_quota_ has dropped below - // send_quota_low_water_mark_, we may want to consider increasing the "low - // water mark" and "high water mark", but only if we think we are not - // saturating the link to the WebSocket server. + // send_quota_low_water_mark_, it might be good to increase the "low + // water mark" and "high water mark", but only if the link to the WebSocket + // server is not saturated. // TODO(ricea): For kOpCodeText, do UTF-8 validation? scoped_refptr<IOBufferWithSize> buffer(new IOBufferWithSize(data.size())); std::copy(data.begin(), data.end(), buffer->data()); @@ -241,7 +228,7 @@ void WebSocketChannel::OnConnectSuccess(scoped_ptr<WebSocketStream> stream) { current_send_quota_ = send_quota_high_water_mark_; event_interface_->OnFlowControl(send_quota_high_water_mark_); - // We don't need this any more. + // |stream_request_| is not used once the connection has succeeded. stream_request_.reset(); ReadFrames(); } @@ -256,8 +243,8 @@ void WebSocketChannel::OnConnectFailure(uint16 websocket_error) { void WebSocketChannel::WriteFrames() { int result = OK; do { - // This use of base::Unretained is safe because we own the WebSocketStream - // and destroying it cancels all callbacks. + // This use of base::Unretained is safe because this object owns the + // WebSocketStream and destroying it cancels all callbacks. result = stream_->WriteFrames( data_being_sent_->frames(), base::Bind( @@ -317,8 +304,9 @@ void WebSocketChannel::OnWriteDone(bool synchronous, int result) { void WebSocketChannel::ReadFrames() { int result = OK; do { - // This use of base::Unretained is safe because we own the WebSocketStream, - // and any pending reads will be cancelled when it is destroyed. + // This use of base::Unretained is safe because this object owns the + // WebSocketStream, and any pending reads will be cancelled when it is + // destroyed. result = stream_->ReadFrames( &read_frame_chunks_, base::Bind( @@ -345,7 +333,7 @@ void WebSocketChannel::OnReadDone(bool synchronous, int result) { ProcessFrameChunk(chunk.Pass()); } read_frame_chunks_.clear(); - // We need to always keep a call to ReadFrames pending. + // There should always be a call to ReadFrames pending. if (!synchronous && state_ != CLOSED) { ReadFrames(); } @@ -388,10 +376,11 @@ void WebSocketChannel::ProcessFrameChunk( } } if (!current_frame_header_) { - // If we rejected the previous chunk as invalid, then we will have reset - // current_frame_header_ to avoid using it. More chunks of the invalid frame - // may still arrive, so this is not necessarily a bug on our side. However, - // if this happens when state_ is CONNECTED, it is definitely a bug. + // If this channel rejected the previous chunk as invalid, then it will have + // reset |current_frame_header_| and closed the channel. More chunks of the + // invalid frame may still arrive, and it is not necessarily a bug for that + // to happen. However, if this happens when state_ is CONNECTED, it is + // definitely a bug. DCHECK(state_ != CONNECTED) << "Unexpected header-less frame received " << "(final_chunk = " << chunk->final_chunk << ", data size = " << chunk->data->size() @@ -402,7 +391,7 @@ void WebSocketChannel::ProcessFrameChunk( data_buffer.swap(chunk->data); const bool is_final_chunk = chunk->final_chunk; chunk.reset(); - WebSocketFrameHeader::OpCode opcode = current_frame_header_->opcode; + const WebSocketFrameHeader::OpCode opcode = current_frame_header_->opcode; if (WebSocketFrameHeader::IsKnownControlOpCode(opcode)) { if (!current_frame_header_->final) { FailChannel(SEND_REAL_ERROR, @@ -419,32 +408,35 @@ void WebSocketChannel::ProcessFrameChunk( if (!is_final_chunk) { VLOG(2) << "Encountered a split control frame, opcode " << opcode; if (incomplete_control_frame_body_) { - // The really horrid case. We need to create a new IOBufferWithSize - // combining the new one and the old one. This should virtually never - // happen. - // TODO(ricea): This algorithm is O(N^2). Use a fixed 127-byte buffer - // instead. - VLOG(3) << "Hit the really horrid case"; - incomplete_control_frame_body_ = - ConcatenateIOBuffers(incomplete_control_frame_body_, data_buffer); + VLOG(3) << "Appending to an existing split control frame."; + AddToIncompleteControlFrameBody(data_buffer); } else { - // The merely horrid case. Store the IOBufferWithSize to use when the - // rest of the control frame arrives. - incomplete_control_frame_body_.swap(data_buffer); + VLOG(3) << "Creating new storage for an incomplete control frame."; + incomplete_control_frame_body_ = new GrowableIOBuffer(); + // This method checks for oversize control frames above, so as long as + // the frame parser is working correctly, this won't overflow. If a bug + // does cause it to overflow, it will CHECK() in + // AddToIncompleteControlFrameBody() without writing outside the buffer. + incomplete_control_frame_body_->SetCapacity(kMaxControlFramePayload); + AddToIncompleteControlFrameBody(data_buffer); } return; // Handle when complete. } if (incomplete_control_frame_body_) { VLOG(2) << "Rejoining a split control frame, opcode " << opcode; - data_buffer = - ConcatenateIOBuffers(incomplete_control_frame_body_, data_buffer); + AddToIncompleteControlFrameBody(data_buffer); + const int body_size = incomplete_control_frame_body_->offset(); + data_buffer = new IOBufferWithSize(body_size); + memcpy(data_buffer->data(), + incomplete_control_frame_body_->StartOfBuffer(), + body_size); incomplete_control_frame_body_ = NULL; // Frame now complete. } } // Apply basic sanity checks to the |payload_length| field from the frame - // header. We can only apply a strict check when we know we have the whole - // frame in one chunk. + // header. A check for exact equality can only be used when the whole frame + // arrives in one chunk. DCHECK_GE(current_frame_header_->payload_length, base::checked_numeric_cast<uint64>(data_buffer->size())); DCHECK(!is_first_chunk || !is_final_chunk || @@ -455,11 +447,24 @@ void WebSocketChannel::ProcessFrameChunk( HandleFrame(opcode, is_first_chunk, is_final_chunk, data_buffer); if (is_final_chunk) { - // Make sure we do not apply this frame header to any future chunks. + // Make sure that this frame header is not applied to any future chunks. current_frame_header_.reset(); } } +void WebSocketChannel::AddToIncompleteControlFrameBody( + const scoped_refptr<IOBufferWithSize>& data_buffer) { + const int new_offset = + incomplete_control_frame_body_->offset() + data_buffer->size(); + CHECK_GE(incomplete_control_frame_body_->capacity(), new_offset) + << "Control frame body larger than frame header indicates; frame parser " + "bug?"; + memcpy(incomplete_control_frame_body_->data(), + data_buffer->data(), + data_buffer->size()); + incomplete_control_frame_body_->set_offset(new_offset); +} + void WebSocketChannel::HandleFrame( const WebSocketFrameHeader::OpCode opcode, bool is_first_chunk, @@ -496,8 +501,8 @@ void WebSocketChannel::HandleFrame( frame_name = "Unknown frame type"; break; } - // SEND_REAL_ERROR makes no difference here, as we won't send another Close - // frame. + // SEND_REAL_ERROR makes no difference here, as FailChannel() won't send + // another Close frame. FailChannel(SEND_REAL_ERROR, kWebSocketErrorProtocolError, frame_name + " received after close"); @@ -515,12 +520,12 @@ void WebSocketChannel::HandleFrame( const char* const data_begin = data_buffer->data(); const char* const data_end = data_begin + data_buffer->size(); const std::vector<char> data(data_begin, data_end); - // TODO(ricea): Handle the (improbable) case when ReadFrames returns far - // more data at once than we want to send in a single IPC (in which case - // we need to buffer the data and return to the event loop with a - // callback to send the rest in 32K chunks). + // TODO(ricea): Handle the case when ReadFrames returns far + // more data at once than should be sent in a single IPC. This needs to + // be handled carefully, as an overloaded IO thread is one possible + // cause of receiving very large chunks. - // Send the received frame to the renderer process. + // Sends the received frame to the renderer process. event_interface_->OnDataFrame( final, is_first_chunk ? opcode : WebSocketFrameHeader::kOpCodeContinuation, @@ -542,7 +547,7 @@ void WebSocketChannel::HandleFrame( case WebSocketFrameHeader::kOpCodePong: VLOG(1) << "Got Pong of size " << data_buffer->size(); - // We do not need to do anything with pong messages. + // There is no need to do anything with pong messages. return; case WebSocketFrameHeader::kOpCodeClose: { @@ -599,10 +604,10 @@ void WebSocketChannel::SendIOBufferWithSize( chunk->final_chunk = true; chunk->data = buffer; if (data_being_sent_) { - // Either the link to the WebSocket server is saturated, or we are simply - // processing a batch of messages. - // TODO(ricea): We need to keep some statistics to work out which situation - // we are in and adjust quota appropriately. + // Either the link to the WebSocket server is saturated, or several messages + // are being sent in a batch. + // TODO(ricea): Keep some statistics to work out the situation and adjust + // quota appropriately. if (!data_to_send_next_) data_to_send_next_.reset(new SendBuffer); data_to_send_next_->AddFrame(chunk.Pass()); @@ -629,13 +634,14 @@ void WebSocketChannel::FailChannel(ExposeError expose, } SendClose(send_code, send_reason); // Sets state_ to SEND_CLOSED } - // Careful study of RFC6455 section 7.1.7 and 7.1.1 indicates we should close - // the connection ourselves without waiting for the closing handshake. + // Careful study of RFC6455 section 7.1.7 and 7.1.1 indicates the browser + // should close the connection itself without waiting for the closing + // handshake. stream_->Close(); state_ = CLOSED; - // We may be in the middle of processing several chunks. We should not re-use - // the frame header. + // The channel may be in the middle of processing several chunks. It should + // not use this frame header for subsequent chunks. current_frame_header_.reset(); if (old_state != CLOSED) { event_interface_->OnDropChannel(code, reason); diff --git a/chromium/net/websockets/websocket_channel.h b/chromium/net/websockets/websocket_channel.h index d81f83a1821..c997d6adef4 100644 --- a/chromium/net/websockets/websocket_channel.h +++ b/chromium/net/websockets/websocket_channel.h @@ -19,8 +19,10 @@ namespace net { +class GrowableIOBuffer; class URLRequestContext; class WebSocketEventInterface; +class BoundNetLog; // Transport-independent implementation of WebSockets. Implements protocol // semantics that do not depend on the underlying transport. Provides the @@ -72,7 +74,7 @@ class NET_EXPORT WebSocketChannel { // send up to |quota| units of data. void SendFlowControl(int64 quota); - // Start the closing handshake for a client-initiated shutdown of the + // Starts the closing handshake for a client-initiated shutdown of the // connection. There is no API to close the connection without a closing // handshake, but destroying the WebSocketChannel object while connected will // effectively do that. |code| must be in the range 1000-4999. |reason| should @@ -92,16 +94,16 @@ class NET_EXPORT WebSocketChannel { const WebSocketStreamFactory& factory); private: - // We have a simple linear progression of states from FRESHLY_CONSTRUCTED to - // CLOSED, except that the SEND_CLOSED and RECV_CLOSED states may be skipped - // in case of error. + // The object passes through a linear progression of states from + // FRESHLY_CONSTRUCTED to CLOSED, except that the SEND_CLOSED and RECV_CLOSED + // states may be skipped in case of error. enum State { FRESHLY_CONSTRUCTED, CONNECTING, CONNECTED, - SEND_CLOSED, // We have sent a Close frame but not received a Close frame. + SEND_CLOSED, // A Close frame has been sent but not received. RECV_CLOSED, // Used briefly between receiving a Close frame and sending - // the response. Once we have responded, the state changes + // the response. Once the response is sent, the state changes // to CLOSED. CLOSE_WAIT, // The Closing Handshake has completed, but the remote server // has not yet closed the connection. @@ -109,17 +111,18 @@ class NET_EXPORT WebSocketChannel { // has been closed; or the connection is failed. }; - // When failing a channel, we may or may not want to send the real reason for - // failing to the remote server. This enum is used by FailChannel() to - // choose. + // When failing a channel, sometimes it is inappropriate to expose the real + // reason for failing to the remote server. This enum is used by FailChannel() + // to select between sending the real status or a "Going Away" status. enum ExposeError { SEND_REAL_ERROR, SEND_GOING_AWAY, }; - // Our implementation of WebSocketStream::ConnectDelegate. We do not inherit - // from WebSocketStream::ConnectDelegate directly to avoid cluttering our - // public interface with the implementation of those methods, and because the + // Implementation of WebSocketStream::ConnectDelegate for + // WebSocketChannel. WebSocketChannel does not inherit from + // WebSocketStream::ConnectDelegate directly to avoid cluttering the public + // interface with the implementation of those methods, and because the // lifetime of a WebSocketChannel is longer than the lifetime of the // connection process. class ConnectDelegate; @@ -164,7 +167,11 @@ class NET_EXPORT WebSocketChannel { // Processes a single chunk that has been read from the stream. void ProcessFrameChunk(scoped_ptr<WebSocketFrameChunk> chunk); - // Handle a frame that we have received enough of to process. May call + // Appends |data_buffer| to |incomplete_control_frame_body_|. + void AddToIncompleteControlFrameBody( + const scoped_refptr<IOBufferWithSize>& data_buffer); + + // Handles a frame that the object has received enough of to process. May call // event_interface_ methods, send responses to the server, and change the // value of state_. void HandleFrame(const WebSocketFrameHeader::OpCode opcode, @@ -181,13 +188,13 @@ class NET_EXPORT WebSocketChannel { WebSocketFrameHeader::OpCode op_code, const scoped_refptr<IOBufferWithSize>& buffer); - // Perform the "Fail the WebSocket Connection" operation as defined in + // Performs the "Fail the WebSocket Connection" operation as defined in // RFC6455. The supplied code and reason are sent back to the renderer in an // OnDropChannel message. If state_ is CONNECTED then a Close message is sent // to the remote host. If |expose| is SEND_REAL_ERROR then the remote host is - // given the same status code we gave the renderer; otherwise it is sent a - // fixed "Going Away" code. Resets current_frame_header_, closes the - // stream_, and sets state_ to CLOSED. + // given the same status code passed to the renderer; otherwise it is sent a + // fixed "Going Away" code. Resets current_frame_header_, closes the stream_, + // and sets state_ to CLOSED. void FailChannel(ExposeError expose, uint16 code, const std::string& reason); // Sends a Close frame to Start the WebSocket Closing Handshake, or to respond @@ -205,13 +212,13 @@ class NET_EXPORT WebSocketChannel { uint16* code, std::string* reason); - // The URL to which we connect. + // The URL of the remote server. const GURL socket_url_; // The object receiving events. const scoped_ptr<WebSocketEventInterface> event_interface_; - // The WebSocketStream to which we are sending/receiving data. + // The WebSocketStream on which to send and receive data. scoped_ptr<WebSocketStream> stream_; // A data structure containing a vector of frames to be sent and the total @@ -226,32 +233,32 @@ class NET_EXPORT WebSocketChannel { // Destination for the current call to WebSocketStream::ReadFrames ScopedVector<WebSocketFrameChunk> read_frame_chunks_; // Frame header for the frame currently being received. Only non-NULL while we - // are processing the frame. If the frame arrives in multiple chunks, can - // remain non-NULL while we wait for additional chunks to arrive. If the - // header of the frame was invalid, this is set to NULL, the channel is - // failed, and subsequent chunks of the same frame will be ignored. + // are processing the frame. If the frame arrives in multiple chunks, it can + // remain non-NULL until additional chunks arrive. If the header of the frame + // was invalid, this is set to NULL, the channel is failed, and subsequent + // chunks of the same frame will be ignored. scoped_ptr<WebSocketFrameHeader> current_frame_header_; // Handle to an in-progress WebSocketStream creation request. Only non-NULL // during the connection process. scoped_ptr<WebSocketStreamRequest> stream_request_; - // Although it will almost never happen in practice, we can be passed an - // incomplete control frame, in which case we need to keep the data around - // long enough to reassemble it. This variable will be NULL the rest of the - // time. - scoped_refptr<IOBufferWithSize> incomplete_control_frame_body_; - // The point at which we give the renderer a quota refresh (quota units). - // "quota units" are currently bytes. TODO(ricea): Update the definition of - // quota units when necessary. + // Although it should rarely happen in practice, a control frame can arrive + // broken into chunks. This variable provides storage for a partial control + // frame until the rest arrives. It will be NULL the rest of the time. + scoped_refptr<GrowableIOBuffer> incomplete_control_frame_body_; + // If the renderer's send quota reaches this level, it is sent a quota + // refresh. "quota units" are currently bytes. TODO(ricea): Update the + // definition of quota units when necessary. int send_quota_low_water_mark_; - // The amount which we refresh the quota to when it reaches the - // low_water_mark (quota units). + // The level the quota is refreshed to when it reaches the low_water_mark + // (quota units). int send_quota_high_water_mark_; // The current amount of quota that the renderer has available for sending // on this logical channel (quota units). int current_send_quota_; - // Storage for the status code and reason from the time we receive the Close - // frame until the connection is closed and we can call OnDropChannel(). + // Storage for the status code and reason from the time the Close frame + // arrives until the connection is closed and they are passed to + // OnDropChannel(). uint16 closing_code_; std::string closing_reason_; diff --git a/chromium/net/websockets/websocket_deflater.cc b/chromium/net/websockets/websocket_deflater.cc new file mode 100644 index 00000000000..c5a4e6a31f1 --- /dev/null +++ b/chromium/net/websockets/websocket_deflater.cc @@ -0,0 +1,128 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/websockets/websocket_deflater.h" + +#include <string.h> +#include <algorithm> +#include <deque> +#include <vector> + +#include "base/logging.h" +#include "net/base/io_buffer.h" +#include "third_party/zlib/zlib.h" + +namespace net { + +WebSocketDeflater::WebSocketDeflater(ContextTakeOverMode mode) + : mode_(mode), are_bytes_added_(false) {} + +WebSocketDeflater::~WebSocketDeflater() { + if (stream_) { + deflateEnd(stream_.get()); + stream_.reset(NULL); + } +} + +bool WebSocketDeflater::Initialize(int window_bits) { + DCHECK(!stream_); + stream_.reset(new z_stream); + + DCHECK_LE(8, window_bits); + DCHECK_GE(15, window_bits); + memset(stream_.get(), 0, sizeof(*stream_)); + int result = deflateInit2(stream_.get(), + Z_DEFAULT_COMPRESSION, + Z_DEFLATED, + -window_bits, // Negative value for raw deflate + 8, // default mem level + Z_DEFAULT_STRATEGY); + if (result != Z_OK) { + deflateEnd(stream_.get()); + stream_.reset(); + return false; + } + const size_t kFixedBufferSize = 4096; + fixed_buffer_.resize(kFixedBufferSize); + return true; +} + +bool WebSocketDeflater::AddBytes(const char* data, size_t size) { + if (!size) + return true; + + are_bytes_added_ = true; + stream_->next_in = reinterpret_cast<Bytef*>(const_cast<char*>(data)); + stream_->avail_in = size; + + int result = Deflate(Z_NO_FLUSH); + DCHECK(result != Z_BUF_ERROR || !stream_->avail_in); + return result == Z_BUF_ERROR; +} + +bool WebSocketDeflater::Finish() { + if (!are_bytes_added_) { + // Since consecutive calls of deflate with Z_SYNC_FLUSH and no input + // lead to an error, we create and return the output for the empty input + // manually. + buffer_.push_back('\x02'); + buffer_.push_back('\x00'); + ResetContext(); + return true; + } + stream_->next_in = NULL; + stream_->avail_in = 0; + + int result = Deflate(Z_SYNC_FLUSH); + // Deflate returning Z_BUF_ERROR means that it's successfully flushed and + // blocked for input data. + if (result != Z_BUF_ERROR) { + ResetContext(); + return false; + } + // Remove 4 octets from the tail as the specification requires. + if (CurrentOutputSize() < 4) { + ResetContext(); + return false; + } + buffer_.resize(buffer_.size() - 4); + ResetContext(); + return true; +} + +void WebSocketDeflater::PushSyncMark() { + DCHECK(!are_bytes_added_); + const char data[] = {'\x00', '\x00', '\xff', '\xff'}; + buffer_.insert(buffer_.end(), &data[0], &data[sizeof(data)]); +} + +scoped_refptr<IOBufferWithSize> WebSocketDeflater::GetOutput(size_t size) { + std::deque<char>::iterator begin = buffer_.begin(); + std::deque<char>::iterator end = begin + std::min(size, buffer_.size()); + + scoped_refptr<IOBufferWithSize> result = new IOBufferWithSize(end - begin); + std::copy(begin, end, result->data()); + buffer_.erase(begin, end); + return result; +} + +void WebSocketDeflater::ResetContext() { + if (mode_ == DO_NOT_TAKE_OVER_CONTEXT) + deflateReset(stream_.get()); + are_bytes_added_ = false; +} + +int WebSocketDeflater::Deflate(int flush) { + int result = Z_OK; + do { + stream_->next_out = reinterpret_cast<Bytef*>(&fixed_buffer_[0]); + stream_->avail_out = fixed_buffer_.size(); + result = deflate(stream_.get(), flush); + size_t size = fixed_buffer_.size() - stream_->avail_out; + buffer_.insert(buffer_.end(), &fixed_buffer_[0], &fixed_buffer_[size]); + } while (result == Z_OK); + return result; +} + +} // namespace net diff --git a/chromium/net/websockets/websocket_deflater.h b/chromium/net/websockets/websocket_deflater.h new file mode 100644 index 00000000000..da85bfec912 --- /dev/null +++ b/chromium/net/websockets/websocket_deflater.h @@ -0,0 +1,75 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_WEBSOCKETS_WEBSOCKET_DEFLATER_H_ +#define NET_WEBSOCKETS_WEBSOCKET_DEFLATER_H_ + +#include <deque> +#include <vector> + +#include "base/basictypes.h" +#include "base/memory/ref_counted.h" +#include "base/memory/scoped_ptr.h" +#include "net/base/net_export.h" + +extern "C" struct z_stream_s; + +namespace net { + +class IOBufferWithSize; + +class NET_EXPORT_PRIVATE WebSocketDeflater { + public: + enum ContextTakeOverMode { + DO_NOT_TAKE_OVER_CONTEXT, + TAKE_OVER_CONTEXT, + }; + + explicit WebSocketDeflater(ContextTakeOverMode mode); + ~WebSocketDeflater(); + + // Returns true if there is no error and false otherwise. + // This function must be called exactly once before calling any of + // following methods. + // |window_bits| must be between 8 and 15 (both inclusive). + bool Initialize(int window_bits); + + // Adds bytes to |stream_|. + // Returns true if there is no error and false otherwise. + bool AddBytes(const char* data, size_t size); + + // Flushes the current processing data. + // Returns true if there is no error and false otherwise. + bool Finish(); + + // Pushes "\x00\x00\xff\xff" to the end of the buffer. + void PushSyncMark(); + + // Returns the current deflated output. + // If the current output is larger than |size| bytes, + // returns the first |size| bytes of the current output. + // The returned bytes will be dropped from the current output and never be + // returned thereafter. + scoped_refptr<IOBufferWithSize> GetOutput(size_t size); + + // Returns the size of the current deflated output. + size_t CurrentOutputSize() const { return buffer_.size(); } + + private: + void ResetContext(); + int Deflate(int flush); + + scoped_ptr<z_stream_s> stream_; + ContextTakeOverMode mode_; + std::deque<char> buffer_; + std::vector<char> fixed_buffer_; + // true if bytes were added after last Finish(). + bool are_bytes_added_; + + DISALLOW_COPY_AND_ASSIGN(WebSocketDeflater); +}; + +} // namespace net + +#endif // NET_WEBSOCKETS_WEBSOCKET_DEFLATER_H_ diff --git a/chromium/net/websockets/websocket_deflater_test.cc b/chromium/net/websockets/websocket_deflater_test.cc new file mode 100644 index 00000000000..03b8a3d7c52 --- /dev/null +++ b/chromium/net/websockets/websocket_deflater_test.cc @@ -0,0 +1,138 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/websockets/websocket_deflater.h" + +#include <string> + +#include "base/memory/ref_counted.h" +#include "net/base/io_buffer.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace net { + +namespace { + +std::string ToString(IOBufferWithSize* buffer) { + return std::string(buffer->data(), buffer->size()); +} + +TEST(WebSocketDeflaterTest, Construct) { + WebSocketDeflater deflater(WebSocketDeflater::TAKE_OVER_CONTEXT); + deflater.Initialize(8); + ASSERT_EQ(0u, deflater.CurrentOutputSize()); + ASSERT_TRUE(deflater.Finish()); + scoped_refptr<IOBufferWithSize> actual = + deflater.GetOutput(deflater.CurrentOutputSize()); + EXPECT_EQ(std::string("\x02\00", 2), ToString(actual.get())); + ASSERT_EQ(0u, deflater.CurrentOutputSize()); +} + +TEST(WebSocketDeflaterTest, DeflateHelloTakeOverContext) { + WebSocketDeflater deflater(WebSocketDeflater::TAKE_OVER_CONTEXT); + deflater.Initialize(15); + scoped_refptr<IOBufferWithSize> actual1, actual2; + + ASSERT_TRUE(deflater.AddBytes("Hello", 5)); + ASSERT_TRUE(deflater.Finish()); + actual1 = deflater.GetOutput(deflater.CurrentOutputSize()); + EXPECT_EQ(std::string("\xf2\x48\xcd\xc9\xc9\x07\x00", 7), + ToString(actual1.get())); + + ASSERT_TRUE(deflater.AddBytes("Hello", 5)); + ASSERT_TRUE(deflater.Finish()); + actual2 = deflater.GetOutput(deflater.CurrentOutputSize()); + EXPECT_EQ(std::string("\xf2\x00\x11\x00\x00", 5), ToString(actual2.get())); +} + +TEST(WebSocketDeflaterTest, DeflateHelloDoNotTakeOverContext) { + WebSocketDeflater deflater(WebSocketDeflater::DO_NOT_TAKE_OVER_CONTEXT); + deflater.Initialize(15); + scoped_refptr<IOBufferWithSize> actual1, actual2; + + ASSERT_TRUE(deflater.AddBytes("Hello", 5)); + ASSERT_TRUE(deflater.Finish()); + actual1 = deflater.GetOutput(deflater.CurrentOutputSize()); + EXPECT_EQ(std::string("\xf2\x48\xcd\xc9\xc9\x07\x00", 7), + ToString(actual1.get())); + + ASSERT_TRUE(deflater.AddBytes("Hello", 5)); + ASSERT_TRUE(deflater.Finish()); + actual2 = deflater.GetOutput(deflater.CurrentOutputSize()); + EXPECT_EQ(std::string("\xf2\x48\xcd\xc9\xc9\x07\x00", 7), + ToString(actual2.get())); +} + +TEST(WebSocketDeflaterTest, MultipleAddBytesCalls) { + WebSocketDeflater deflater(WebSocketDeflater::DO_NOT_TAKE_OVER_CONTEXT); + deflater.Initialize(15); + std::string input(32, 'a'); + scoped_refptr<IOBufferWithSize> actual; + + for (size_t i = 0; i < input.size(); ++i) { + ASSERT_TRUE(deflater.AddBytes(&input[i], 1)); + } + ASSERT_TRUE(deflater.Finish()); + actual = deflater.GetOutput(deflater.CurrentOutputSize()); + EXPECT_EQ(std::string("\x4a\x4c\xc4\x0f\x00\x00", 6), ToString(actual.get())); +} + +TEST(WebSocketDeflaterTest, GetMultipleDeflatedOutput) { + WebSocketDeflater deflater(WebSocketDeflater::TAKE_OVER_CONTEXT); + deflater.Initialize(15); + scoped_refptr<IOBufferWithSize> actual; + + ASSERT_TRUE(deflater.AddBytes("Hello", 5)); + ASSERT_TRUE(deflater.Finish()); + deflater.PushSyncMark(); + ASSERT_TRUE(deflater.Finish()); + deflater.PushSyncMark(); + ASSERT_TRUE(deflater.AddBytes("Hello", 5)); + ASSERT_TRUE(deflater.Finish()); + + actual = deflater.GetOutput(deflater.CurrentOutputSize()); + EXPECT_EQ(std::string("\xf2\x48\xcd\xc9\xc9\x07\x00\x00\x00\xff\xff" + "\x02\x00\x00\x00\xff\xff" + "\xf2\x00\x11\x00\x00", 22), + ToString(actual.get())); + ASSERT_EQ(0u, deflater.CurrentOutputSize()); +} + +TEST(WebSocketDeflaterTest, WindowBits8) { + WebSocketDeflater deflater(WebSocketDeflater::DO_NOT_TAKE_OVER_CONTEXT); + deflater.Initialize(8); + // Set the head and tail of |input| so that back-reference + // can be used if the window size is sufficiently-large. + const std::string word = "Chromium"; + std::string input = word + std::string(256, 'a') + word; + scoped_refptr<IOBufferWithSize> actual; + + ASSERT_TRUE(deflater.AddBytes(input.data(), input.size())); + ASSERT_TRUE(deflater.Finish()); + actual = deflater.GetOutput(deflater.CurrentOutputSize()); + EXPECT_EQ(std::string("r\xce(\xca\xcf\xcd,\xcdM\x1c\xe1\xc0\x39\xa3" + "(?7\xb3\x34\x17\x00", 21), + ToString(actual.get())); +} + +TEST(WebSocketDeflaterTest, WindowBits10) { + WebSocketDeflater deflater(WebSocketDeflater::DO_NOT_TAKE_OVER_CONTEXT); + deflater.Initialize(10); + // Set the head and tail of |input| so that back-reference + // can be used if the window size is sufficiently-large. + const std::string word = "Chromium"; + std::string input = word + std::string(256, 'a') + word; + scoped_refptr<IOBufferWithSize> actual; + + ASSERT_TRUE(deflater.AddBytes(input.data(), input.size())); + ASSERT_TRUE(deflater.Finish()); + actual = deflater.GetOutput(deflater.CurrentOutputSize()); + EXPECT_EQ( + std::string("r\xce(\xca\xcf\xcd,\xcdM\x1c\xe1\xc0\x19\x1a\x0e\0\0", 17), + ToString(actual.get())); +} + +} // namespace + +} // namespace net diff --git a/chromium/net/websockets/websocket_errors_unittest.cc b/chromium/net/websockets/websocket_errors_test.cc index 1a4fca61453..5e48e20e831 100644 --- a/chromium/net/websockets/websocket_errors_unittest.cc +++ b/chromium/net/websockets/websocket_errors_test.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2013 The Chromium Authors. All rights reserved. +// Copyright 2013 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. diff --git a/chromium/net/websockets/websocket_extension.cc b/chromium/net/websockets/websocket_extension.cc new file mode 100644 index 00000000000..edcd8e8657f --- /dev/null +++ b/chromium/net/websockets/websocket_extension.cc @@ -0,0 +1,43 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/websockets/websocket_extension.h" + +#include <string> + +#include "base/logging.h" + +namespace net { + +WebSocketExtension::Parameter::Parameter(const std::string& name) + : name_(name) {} + +WebSocketExtension::Parameter::Parameter(const std::string& name, + const std::string& value) + : name_(name), value_(value) { + DCHECK(!value.empty()); +} + +bool WebSocketExtension::Parameter::Equals(const Parameter& other) const { + return name_ == other.name_ && value_ == other.value_; +} + +WebSocketExtension::WebSocketExtension() {} + +WebSocketExtension::WebSocketExtension(const std::string& name) + : name_(name) {} + +WebSocketExtension::~WebSocketExtension() {} + +bool WebSocketExtension::Equals(const WebSocketExtension& other) const { + if (name_ != other.name_) return false; + if (parameters_.size() != other.parameters_.size()) return false; + for (size_t i = 0; i < other.parameters_.size(); ++i) { + if (!parameters_[i].Equals(other.parameters_[i])) + return false; + } + return true; +} + +} // namespace net diff --git a/chromium/net/websockets/websocket_extension.h b/chromium/net/websockets/websocket_extension.h new file mode 100644 index 00000000000..5af4023869f --- /dev/null +++ b/chromium/net/websockets/websocket_extension.h @@ -0,0 +1,57 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_WEBSOCKETS_WEBSOCKET_EXTENSION_H_ +#define NET_WEBSOCKETS_WEBSOCKET_EXTENSION_H_ + +#include <string> +#include <vector> + +#include "net/base/net_export.h" + +namespace net { + +// A WebSocketExtension instance represents a WebSocket extension specified +// in RFC6455. +class NET_EXPORT_PRIVATE WebSocketExtension { + public: + // Note that RFC6455 does not allow a parameter with an empty value. + class NET_EXPORT_PRIVATE Parameter { + public: + // Construct a parameter which does not have a value. + explicit Parameter(const std::string& name); + // Construct a parameter with a non-empty value. + Parameter(const std::string& name, const std::string& value); + + bool HasValue() const { return !value_.empty(); } + const std::string& name() const { return name_; } + const std::string& value() const { return value_; } + bool Equals(const Parameter& other) const; + + // The default copy constructor and the assignment operator are defined: + // we need them. + private: + std::string name_; + std::string value_; + }; + + WebSocketExtension(); + explicit WebSocketExtension(const std::string& name); + ~WebSocketExtension(); + + void Add(const Parameter& parameter) { parameters_.push_back(parameter); } + const std::string& name() const { return name_; } + const std::vector<Parameter>& parameters() const { return parameters_; } + bool Equals(const WebSocketExtension& other) const; + + // The default copy constructor and the assignment operator are defined: + // we need them. + private: + std::string name_; + std::vector<Parameter> parameters_; +}; + +} // namespace net + +#endif // NET_WEBSOCKETS_WEBSOCKET_EXTENSION_H_ diff --git a/chromium/net/websockets/websocket_extension_parser.cc b/chromium/net/websockets/websocket_extension_parser.cc new file mode 100644 index 00000000000..28a2db16f27 --- /dev/null +++ b/chromium/net/websockets/websocket_extension_parser.cc @@ -0,0 +1,158 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/websockets/websocket_extension_parser.h" + +#include "base/strings/string_util.h" + +namespace net { + +WebSocketExtensionParser::WebSocketExtensionParser() {} + +WebSocketExtensionParser::~WebSocketExtensionParser() {} + +void WebSocketExtensionParser::Parse(const char* data, size_t size) { + current_ = data; + end_ = data + size; + has_error_ = false; + + ConsumeExtension(&extension_); + if (has_error_) return; + ConsumeSpaces(); + has_error_ = has_error_ || (current_ != end_); +} + +void WebSocketExtensionParser::Consume(char c) { + DCHECK(!has_error_); + ConsumeSpaces(); + DCHECK(!has_error_); + if (current_ == end_ || c != current_[0]) { + has_error_ = true; + return; + } + ++current_; +} + +void WebSocketExtensionParser::ConsumeExtension(WebSocketExtension* extension) { + DCHECK(!has_error_); + base::StringPiece name; + ConsumeToken(&name); + if (has_error_) return; + *extension = WebSocketExtension(name.as_string()); + + while (ConsumeIfMatch(';')) { + WebSocketExtension::Parameter parameter((std::string())); + ConsumeExtensionParameter(¶meter); + if (has_error_) return; + extension->Add(parameter); + } +} + +void WebSocketExtensionParser::ConsumeExtensionParameter( + WebSocketExtension::Parameter* parameter) { + DCHECK(!has_error_); + base::StringPiece name, value; + std::string value_string; + + ConsumeToken(&name); + if (has_error_) return; + if (!ConsumeIfMatch('=')) { + *parameter = WebSocketExtension::Parameter(name.as_string()); + return; + } + + if (Lookahead('\"')) { + ConsumeQuotedToken(&value_string); + } else { + ConsumeToken(&value); + value_string = value.as_string(); + } + if (has_error_) return; + *parameter = WebSocketExtension::Parameter(name.as_string(), value_string); +} + +void WebSocketExtensionParser::ConsumeToken(base::StringPiece* token) { + DCHECK(!has_error_); + ConsumeSpaces(); + DCHECK(!has_error_); + const char* head = current_; + while (current_ < end_ && + !IsControl(current_[0]) && !IsSeparator(current_[0])) + ++current_; + if (current_ == head) { + has_error_ = true; + return; + } + *token = base::StringPiece(head, current_ - head); +} + +void WebSocketExtensionParser::ConsumeQuotedToken(std::string* token) { + DCHECK(!has_error_); + Consume('"'); + if (has_error_) return; + *token = ""; + while (current_ < end_ && !IsControl(current_[0])) { + if (UnconsumedBytes() >= 2 && current_[0] == '\\') { + char next = current_[1]; + if (IsControl(next) || IsSeparator(next)) break; + *token += next; + current_ += 2; + } else if (IsSeparator(current_[0])) { + break; + } else { + *token += current_[0]; + ++current_; + } + } + // We can't use Consume here because we don't want to consume spaces. + if (current_ < end_ && current_[0] == '"') + ++current_; + else + has_error_ = true; + has_error_ = has_error_ || token->empty(); +} + +void WebSocketExtensionParser::ConsumeSpaces() { + DCHECK(!has_error_); + while (current_ < end_ && (current_[0] == ' ' || current_[0] == '\t')) + ++current_; + return; +} + +bool WebSocketExtensionParser::Lookahead(char c) { + DCHECK(!has_error_); + const char* head = current_; + + Consume(c); + bool result = !has_error_; + current_ = head; + has_error_ = false; + return result; +} + +bool WebSocketExtensionParser::ConsumeIfMatch(char c) { + DCHECK(!has_error_); + const char* head = current_; + + Consume(c); + if (has_error_) { + current_ = head; + has_error_ = false; + return false; + } + return true; +} + +// static +bool WebSocketExtensionParser::IsControl(char c) { + return (0 <= c && c <= 31) || c == 127; +} + +// static +bool WebSocketExtensionParser::IsSeparator(char c) { + const char separators[] = "()<>@,;:\\\"/[]?={} \t"; + return strchr(separators, c) != NULL; +} + +} // namespace net diff --git a/chromium/net/websockets/websocket_extension_parser.h b/chromium/net/websockets/websocket_extension_parser.h new file mode 100644 index 00000000000..ef7fe036657 --- /dev/null +++ b/chromium/net/websockets/websocket_extension_parser.h @@ -0,0 +1,59 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_WEBSOCKETS_WEBSOCKET_EXTENSION_PARSER_H_ +#define NET_WEBSOCKETS_WEBSOCKET_EXTENSION_PARSER_H_ + +#include <string> + +#include "base/strings/string_piece.h" +#include "net/base/net_export.h" +#include "net/websockets/websocket_extension.h" + +namespace net { + +class NET_EXPORT_PRIVATE WebSocketExtensionParser { + public: + WebSocketExtensionParser(); + ~WebSocketExtensionParser(); + + // Parses the given string as a WebSocket extension header value. + // This parser assumes some preprocesses are made. + // - The parser parses single extension at a time. This means that + // the parser parses |extension| in RFC6455 9.1, not |extension-list|. + // - There is no newline characters in the input. LWS-concatenation must + // have already been done. + void Parse(const char* data, size_t size); + void Parse(const std::string& data) { + Parse(data.data(), data.size()); + } + + bool has_error() const { return has_error_; } + const WebSocketExtension& extension() const { return extension_; } + + private: + void Consume(char c); + void ConsumeExtension(WebSocketExtension* extension); + void ConsumeExtensionParameter(WebSocketExtension::Parameter* parameter); + void ConsumeToken(base::StringPiece* token); + void ConsumeQuotedToken(std::string* token); + void ConsumeSpaces(); + bool Lookahead(char c); + bool ConsumeIfMatch(char c); + size_t UnconsumedBytes() const { return end_ - current_; } + + static bool IsControl(char c); + static bool IsSeparator(char c); + + const char* current_; + const char* end_; + bool has_error_; + WebSocketExtension extension_; + + DISALLOW_COPY_AND_ASSIGN(WebSocketExtensionParser); +}; + +} // namespace net + +#endif // NET_WEBSOCKETS_WEBSOCKET_EXTENSION_PARSER_H_ diff --git a/chromium/net/websockets/websocket_extension_parser_test.cc b/chromium/net/websockets/websocket_extension_parser_test.cc new file mode 100644 index 00000000000..dc7dc859d83 --- /dev/null +++ b/chromium/net/websockets/websocket_extension_parser_test.cc @@ -0,0 +1,122 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/websockets/websocket_extension_parser.h" + +#include <string> + +#include "net/websockets/websocket_extension.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace net { + +namespace { + +TEST(WebSocketExtensionParserTest, ParseEmpty) { + WebSocketExtensionParser parser; + parser.Parse("", 0); + + EXPECT_TRUE(parser.has_error()); +} + +TEST(WebSocketExtensionParserTest, ParseSimple) { + WebSocketExtensionParser parser; + WebSocketExtension expected("foo"); + + parser.Parse("foo"); + + ASSERT_FALSE(parser.has_error()); + EXPECT_TRUE(expected.Equals(parser.extension())); +} + +TEST(WebSocketExtensionParserTest, ParseOneExtensionWithOneParamWithoutValue) { + WebSocketExtensionParser parser; + WebSocketExtension expected("foo"); + expected.Add(WebSocketExtension::Parameter("bar")); + + parser.Parse("\tfoo ; bar"); + + ASSERT_FALSE(parser.has_error()); + EXPECT_TRUE(expected.Equals(parser.extension())); +} + +TEST(WebSocketExtensionParserTest, ParseOneExtensionWithOneParamWithValue) { + WebSocketExtensionParser parser; + WebSocketExtension expected("foo"); + expected.Add(WebSocketExtension::Parameter("bar", "baz")); + + parser.Parse("foo ; bar= baz\t"); + + ASSERT_FALSE(parser.has_error()); + EXPECT_TRUE(expected.Equals(parser.extension())); +} + +TEST(WebSocketExtensionParserTest, ParseOneExtensionWithParams) { + WebSocketExtensionParser parser; + WebSocketExtension expected("foo"); + expected.Add(WebSocketExtension::Parameter("bar", "baz")); + expected.Add(WebSocketExtension::Parameter("hoge", "fuga")); + + parser.Parse("foo ; bar= baz;\t \thoge\t\t=fuga"); + + ASSERT_FALSE(parser.has_error()); + EXPECT_TRUE(expected.Equals(parser.extension())); +} + +TEST(WebSocketExtensionParserTest, InvalidPatterns) { + const char* patterns[] = { + "fo\ao", // control in extension name + "fo\x01o", // control in extension name + "fo<o", // separator in extension name + "foo/", // separator in extension name + ";bar", // empty extension name + "foo bar", // missing ';' + "foo;", // extension parameter without name and value + "foo; b\ar", // control in parameter name + "foo; b\x7fr", // control in parameter name + "foo; b[r", // separator in parameter name + "foo; ba:", // separator in parameter name + "foo; =baz", // empty parameter name + "foo; bar=", // empty parameter value + "foo; =", // empty parameter name and value + "foo; bar=b\x02z", // control in parameter value + "foo; bar=b@z", // separator in parameter value + "foo; bar=b\\z", // separator in parameter value + "foo; bar=b?z", // separator in parameter value + "\"foo\"", // quoted extension name + "foo; \"bar\"", // quoted parameter name + "foo; bar=\"\a2\"", // control in quoted parameter value + "foo; bar=\"b@z\"", // separator in quoted parameter value + "foo; bar=\"b\\\\z\"", // separator in quoted parameter value + "foo; bar=\"\"", // quoted empty parameter value + "foo; bar=\"baz", // unterminated quoted string + "foo; bar=\"baz \"", // space in quoted string + "foo; bar baz", // mising '=' + "foo; bar - baz", // '-' instead of '=' (note: "foo; bar-baz" is valid). + "foo; bar=\r\nbaz", // CRNL not followed by a space + "foo; bar=\r\n baz", // CRNL followed by a space + "foo, bar" // multiple extensions + }; + + for (size_t i = 0; i < arraysize(patterns); ++i) { + WebSocketExtensionParser parser; + parser.Parse(patterns[i]); + EXPECT_TRUE(parser.has_error()); + } +} + +TEST(WebSocketExtensionParserTest, QuotedParameterValue) { + WebSocketExtensionParser parser; + WebSocketExtension expected("foo"); + expected.Add(WebSocketExtension::Parameter("bar", "baz")); + + parser.Parse("foo; bar = \"ba\\z\" "); + + ASSERT_FALSE(parser.has_error()); + EXPECT_TRUE(expected.Equals(parser.extension())); +} + +} // namespace + +} // namespace net diff --git a/chromium/net/websockets/websocket_frame_parser_unittest.cc b/chromium/net/websockets/websocket_frame_parser_test.cc index b2a804027a7..4eb036db28e 100644 --- a/chromium/net/websockets/websocket_frame_parser_unittest.cc +++ b/chromium/net/websockets/websocket_frame_parser_test.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Copyright 2013 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. diff --git a/chromium/net/websockets/websocket_frame_unittest.cc b/chromium/net/websockets/websocket_frame_test.cc index 1652b3b24f9..97fac03e12e 100644 --- a/chromium/net/websockets/websocket_frame_unittest.cc +++ b/chromium/net/websockets/websocket_frame_test.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Copyright 2013 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. diff --git a/chromium/net/websockets/websocket_handshake_constants.cc b/chromium/net/websockets/websocket_handshake_constants.cc new file mode 100644 index 00000000000..357b1349ae9 --- /dev/null +++ b/chromium/net/websockets/websocket_handshake_constants.cc @@ -0,0 +1,36 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/websockets/websocket_handshake_constants.h" + +namespace net { +namespace websockets { + +const char* const kHttpProtocolVersion = "HTTP/1.1"; + +const size_t kRawChallengeLength = 16; + +const char* const kSecWebSocketProtocol = "Sec-WebSocket-Protocol"; +const char* const kSecWebSocketExtensions = "Sec-WebSocket-Extensions"; +const char* const kSecWebSocketKey = "Sec-WebSocket-Key"; +const char* const kSecWebSocketAccept = "Sec-WebSocket-Accept"; +const char* const kSecWebSocketVersion = "Sec-WebSocket-Version"; + +const char* const kUpgrade = "Upgrade"; +const char* const kWebSocketGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + +const char* const kSecWebSocketProtocolSpdy3 = ":sec-websocket-protocol"; +const char* const kSecWebSocketExtensionsSpdy3 = ":sec-websocket-extensions"; + +const char* const kSecWebSocketProtocolLowercase = + kSecWebSocketProtocolSpdy3 + 1; +const char* const kSecWebSocketExtensionsLowercase = + kSecWebSocketExtensionsSpdy3 + 1; +const char* const kSecWebSocketKeyLowercase = "sec-websocket-key"; +const char* const kSecWebSocketVersionLowercase = "sec-websocket-version"; +const char* const kUpgradeLowercase = "upgrade"; +const char* const kWebSocketLowercase = "websocket"; + +} // namespace websockets +} // namespace net diff --git a/chromium/net/websockets/websocket_handshake_constants.h b/chromium/net/websockets/websocket_handshake_constants.h new file mode 100644 index 00000000000..0cf67a8ac7f --- /dev/null +++ b/chromium/net/websockets/websocket_handshake_constants.h @@ -0,0 +1,84 @@ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// A set of common constants that are needed for the WebSocket handshake. +// In general, you should prefer using these constants to literal strings, +// except in tests. +// +// These constants cannot be used in files that are compiled on iOS, because +// this file is not compiled on iOS. + +#ifndef NET_WEBSOCKETS_WEBSOCKET_HANDSHAKE_CONSTANTS_H_ +#define NET_WEBSOCKETS_WEBSOCKET_HANDSHAKE_CONSTANTS_H_ + +#include "base/basictypes.h" + +// This file plases constants inside the ::net::websockets namespace to avoid +// risk of collisions with other symbols in libnet. +namespace net { +namespace websockets { + +// "HTTP/1.1" +// RFC6455 only requires HTTP/1.1 "or better" but in practice an HTTP version +// other than 1.1 should not occur in a WebSocket handshake. +extern const char* const kHttpProtocolVersion; + +// The Sec-WebSockey-Key challenge is 16 random bytes, base64 encoded. +extern const size_t kRawChallengeLength; + +// "Sec-WebSocket-Protocol" +extern const char* const kSecWebSocketProtocol; + +// "Sec-WebSocket-Extensions" +extern const char* const kSecWebSocketExtensions; + +// "Sec-WebSocket-Key" +extern const char* const kSecWebSocketKey; + +// "Sec-WebSocket-Accept" +extern const char* const kSecWebSocketAccept; + +// "Sec-WebSocket-Version" +extern const char* const kSecWebSocketVersion; + +// "Upgrade" +extern const char* const kUpgrade; + +// "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" as defined in section 4.1 of +// RFC6455. +extern const char* const kWebSocketGuid; + +// Colon-prefixed lowercase headers for SPDY3. + +// ":sec-websocket-protocol" +extern const char* const kSecWebSocketProtocolSpdy3; + +// ":sec-websocket-extensions" +extern const char* const kSecWebSocketExtensionsSpdy3; + +// Some parts of the code require lowercase versions of the header names in +// order to do case-insensitive comparisons, or because of SPDY. +// "sec-websocket-protocol" +extern const char* const kSecWebSocketProtocolLowercase; + +// "sec-websocket-extensions" +extern const char* const kSecWebSocketExtensionsLowercase; + +// "sec-webSocket-key" +extern const char* const kSecWebSocketKeyLowercase; + +// "sec-websocket-version" +extern const char* const kSecWebSocketVersionLowercase; + +// "upgrade" +extern const char* const kUpgradeLowercase; + +// "websocket", as used in the "Upgrade:" header. This is always lowercase +// (except in obsolete versions of the protocol). +extern const char* const kWebSocketLowercase; + +} // namespace websockets +} // namespace net + +#endif // NET_WEBSOCKETS_WEBSOCKET_HANDSHAKE_CONSTANTS_H_ diff --git a/chromium/net/websockets/websocket_handshake_handler.cc b/chromium/net/websockets/websocket_handshake_handler.cc index e69bb6dba48..e5b66418afe 100644 --- a/chromium/net/websockets/websocket_handshake_handler.cc +++ b/chromium/net/websockets/websocket_handshake_handler.cc @@ -14,10 +14,13 @@ #include "base/strings/string_tokenizer.h" #include "base/strings/string_util.h" #include "base/strings/stringprintf.h" +#include "net/http/http_request_headers.h" #include "net/http/http_response_headers.h" #include "net/http/http_util.h" +#include "net/websockets/websocket_handshake_constants.h" #include "url/gurl.h" +namespace net { namespace { const size_t kRequestKey3Size = 8U; @@ -27,9 +30,6 @@ const size_t kResponseKeySize = 16U; // require sending "key3" or "response key" data after headers. const int kMinVersionOfHybiNewHandshake = 4; -// Used when we calculate the value of Sec-WebSocket-Accept. -const char* const kWebSocketGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; - void ParseHandshakeHeader( const char* handshake_message, int len, std::string* status_line, @@ -117,8 +117,8 @@ std::string FilterHeaders( int GetVersionFromRequest(const std::string& request_headers) { std::vector<std::string> values; - const char* const headers_to_get[2] = { "sec-websocket-version", - "sec-websocket-draft" }; + const char* const headers_to_get[2] = { + websockets::kSecWebSocketVersionLowercase, "sec-websocket-draft"}; FetchHeaders(request_headers, headers_to_get, 2, &values); DCHECK_LE(values.size(), 1U); if (values.empty()) @@ -130,9 +130,27 @@ int GetVersionFromRequest(const std::string& request_headers) { return version; } -} // namespace +// Append a header to a string. Equivalent to +// response_message += header + ": " + value + "\r\n" +// but avoids unnecessary allocations and copies. +void AppendHeader(const base::StringPiece& header, + const base::StringPiece& value, + std::string* response_message) { + static const char kColonSpace[] = ": "; + const size_t kColonSpaceSize = sizeof(kColonSpace) - 1; + static const char kCrNl[] = "\r\n"; + const size_t kCrNlSize = sizeof(kCrNl) - 1; + + size_t extra_size = + header.size() + kColonSpaceSize + value.size() + kCrNlSize; + response_message->reserve(response_message->size() + extra_size); + response_message->append(header.begin(), header.end()); + response_message->append(kColonSpace, kColonSpace + kColonSpaceSize); + response_message->append(value.begin(), value.end()); + response_message->append(kCrNl, kCrNl + kCrNlSize); +} -namespace net { +} // namespace namespace internal { @@ -242,15 +260,15 @@ HttpRequestInfo WebSocketHandshakeRequestHandler::GetRequestInfo( request_info.extra_headers.Clear(); request_info.extra_headers.AddHeadersFromString(headers_); - request_info.extra_headers.RemoveHeader("Upgrade"); - request_info.extra_headers.RemoveHeader("Connection"); + request_info.extra_headers.RemoveHeader(websockets::kUpgrade); + request_info.extra_headers.RemoveHeader(HttpRequestHeaders::kConnection); if (protocol_version_ >= kMinVersionOfHybiNewHandshake) { std::string key; - bool header_present = - request_info.extra_headers.GetHeader("Sec-WebSocket-Key", &key); + bool header_present = request_info.extra_headers.GetHeader( + websockets::kSecWebSocketKey, &key); DCHECK(header_present); - request_info.extra_headers.RemoveHeader("Sec-WebSocket-Key"); + request_info.extra_headers.RemoveHeader(websockets::kSecWebSocketKey); *challenge = key; } else { challenge->clear(); @@ -294,33 +312,34 @@ bool WebSocketHandshakeRequestHandler::GetRequestHeaderBlock( HttpUtil::HeadersIterator iter(headers_.begin(), headers_.end(), "\r\n"); while (iter.GetNext()) { - if (LowerCaseEqualsASCII(iter.name_begin(), iter.name_end(), "upgrade") || - LowerCaseEqualsASCII(iter.name_begin(), + if (LowerCaseEqualsASCII(iter.name_begin(), iter.name_end(), - "connection") || + websockets::kUpgradeLowercase) || + LowerCaseEqualsASCII( + iter.name_begin(), iter.name_end(), "connection") || LowerCaseEqualsASCII(iter.name_begin(), iter.name_end(), - "sec-websocket-version")) { + websockets::kSecWebSocketVersionLowercase)) { // These headers must be ignored. continue; } else if (LowerCaseEqualsASCII(iter.name_begin(), iter.name_end(), - "sec-websocket-key")) { + websockets::kSecWebSocketKeyLowercase)) { *challenge = iter.values(); // Sec-WebSocket-Key is not sent to the server. continue; - } else if (LowerCaseEqualsASCII(iter.name_begin(), - iter.name_end(), - "host") || - LowerCaseEqualsASCII(iter.name_begin(), - iter.name_end(), - "origin") || - LowerCaseEqualsASCII(iter.name_begin(), - iter.name_end(), - "sec-websocket-protocol") || - LowerCaseEqualsASCII(iter.name_begin(), - iter.name_end(), - "sec-websocket-extensions")) { + } else if (LowerCaseEqualsASCII( + iter.name_begin(), iter.name_end(), "host") || + LowerCaseEqualsASCII( + iter.name_begin(), iter.name_end(), "origin") || + LowerCaseEqualsASCII( + iter.name_begin(), + iter.name_end(), + websockets::kSecWebSocketProtocolLowercase) || + LowerCaseEqualsASCII( + iter.name_begin(), + iter.name_end(), + websockets::kSecWebSocketExtensionsLowercase)) { // TODO(toyoshim): Some WebSocket extensions may not be compatible with // SPDY. We should omit them from a Sec-WebSocket-Extension header. std::string name; @@ -425,28 +444,34 @@ bool WebSocketHandshakeResponseHandler::ParseResponseInfo( if (!response_info.headers.get()) return false; + // TODO(ricea): Eliminate all the reallocations and string copies. std::string response_message; response_message = response_info.headers->GetStatusLine(); response_message += "\r\n"; if (protocol_version_ >= kMinVersionOfHybiNewHandshake) - response_message += "Upgrade: websocket\r\n"; + AppendHeader(websockets::kUpgrade, + websockets::kWebSocketLowercase, + &response_message); else - response_message += "Upgrade: WebSocket\r\n"; - response_message += "Connection: Upgrade\r\n"; + AppendHeader(websockets::kUpgrade, "WebSocket", &response_message); + AppendHeader( + HttpRequestHeaders::kConnection, websockets::kUpgrade, &response_message); if (protocol_version_ >= kMinVersionOfHybiNewHandshake) { - std::string hash = base::SHA1HashString(challenge + kWebSocketGuid); + std::string hash = + base::SHA1HashString(challenge + websockets::kWebSocketGuid); std::string websocket_accept; bool encode_success = base::Base64Encode(hash, &websocket_accept); DCHECK(encode_success); - response_message += "Sec-WebSocket-Accept: " + websocket_accept + "\r\n"; + AppendHeader( + websockets::kSecWebSocketAccept, websocket_accept, &response_message); } void* iter = NULL; std::string name; std::string value; while (response_info.headers->EnumerateHeaderLines(&iter, &name, &value)) { - response_message += name + ": " + value + "\r\n"; + AppendHeader(name, value, &response_message); } response_message += "\r\n"; @@ -473,17 +498,22 @@ bool WebSocketHandshakeResponseHandler::ParseResponseHeaderBlock( status = headers.find(":status"); if (status == headers.end()) return false; - std::string response_message; - response_message = - base::StringPrintf("%s%s\r\n", "HTTP/1.1 ", status->second.c_str()); - response_message += "Upgrade: websocket\r\n"; - response_message += "Connection: Upgrade\r\n"; - std::string hash = base::SHA1HashString(challenge + kWebSocketGuid); + std::string hash = + base::SHA1HashString(challenge + websockets::kWebSocketGuid); std::string websocket_accept; bool encode_success = base::Base64Encode(hash, &websocket_accept); DCHECK(encode_success); - response_message += "Sec-WebSocket-Accept: " + websocket_accept + "\r\n"; + + std::string response_message = base::StringPrintf( + "%s %s\r\n", websockets::kHttpProtocolVersion, status->second.c_str()); + + AppendHeader( + websockets::kUpgrade, websockets::kWebSocketLowercase, &response_message); + AppendHeader( + HttpRequestHeaders::kConnection, websockets::kUpgrade, &response_message); + AppendHeader( + websockets::kSecWebSocketAccept, websocket_accept, &response_message); for (SpdyHeaderBlock::const_iterator iter = headers.begin(); iter != headers.end(); @@ -510,11 +540,13 @@ bool WebSocketHandshakeResponseHandler::ParseResponseHeaderBlock( else tval = value.substr(start); if (spdy_protocol_version >= 3 && - (LowerCaseEqualsASCII(iter->first, ":sec-websocket-protocol") || - LowerCaseEqualsASCII(iter->first, ":sec-websocket-extensions"))) - response_message += iter->first.substr(1) + ": " + tval + "\r\n"; + (LowerCaseEqualsASCII(iter->first, + websockets::kSecWebSocketProtocolSpdy3) || + LowerCaseEqualsASCII(iter->first, + websockets::kSecWebSocketExtensionsSpdy3))) + AppendHeader(iter->first.substr(1), tval, &response_message); else - response_message += iter->first + ": " + tval + "\r\n"; + AppendHeader(iter->first, tval, &response_message); start = end + 1; } while (end != std::string::npos); } diff --git a/chromium/net/websockets/websocket_handshake_handler_spdy_unittest.cc b/chromium/net/websockets/websocket_handshake_handler_spdy_test.cc index a8276cd220a..ebab9a8cf59 100644 --- a/chromium/net/websockets/websocket_handshake_handler_spdy_unittest.cc +++ b/chromium/net/websockets/websocket_handshake_handler_spdy_test.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Copyright 2013 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. diff --git a/chromium/net/websockets/websocket_handshake_handler_unittest.cc b/chromium/net/websockets/websocket_handshake_handler_test.cc index e7d2d75ae21..4c1b15578ee 100644 --- a/chromium/net/websockets/websocket_handshake_handler_unittest.cc +++ b/chromium/net/websockets/websocket_handshake_handler_test.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Copyright 2013 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. diff --git a/chromium/net/websockets/websocket_job_unittest.cc b/chromium/net/websockets/websocket_job_test.cc index 434796dbcbc..434796dbcbc 100644 --- a/chromium/net/websockets/websocket_job_unittest.cc +++ b/chromium/net/websockets/websocket_job_test.cc diff --git a/chromium/net/websockets/websocket_net_log_params_unittest.cc b/chromium/net/websockets/websocket_net_log_params_test.cc index b1c98570402..4690fd66964 100644 --- a/chromium/net/websockets/websocket_net_log_params_unittest.cc +++ b/chromium/net/websockets/websocket_net_log_params_test.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Copyright 2013 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. diff --git a/chromium/net/websockets/websocket_stream.h b/chromium/net/websockets/websocket_stream.h index 6c2a4ff8ef1..4885bbe7295 100644 --- a/chromium/net/websockets/websocket_stream.h +++ b/chromium/net/websockets/websocket_stream.h @@ -92,7 +92,7 @@ class NET_EXPORT_PRIVATE WebSocketStream : public WebSocketStreamBase { // Reads WebSocket frame data. This operation finishes when new frame data // becomes available. Each frame message might be chopped off in the middle - // as specified in the description of WebSocketFrameChunk struct. + // as specified in the description of the WebSocketFrameChunk struct. // |frame_chunks| remains owned by the caller and must be valid until the // operation completes or Close() is called. |frame_chunks| must be empty on // calling. @@ -127,19 +127,22 @@ class NET_EXPORT_PRIVATE WebSocketStream : public WebSocketStreamBase { virtual int ReadFrames(ScopedVector<WebSocketFrameChunk>* frame_chunks, const CompletionCallback& callback) = 0; - // Writes WebSocket frame data. |frame_chunks| must obey the rule specified - // in the documentation of WebSocketFrameChunk struct: the first chunk of - // a WebSocket frame must contain non-NULL |header|, and the last chunk must - // have |final_chunk| field set to true. Series of chunks representing a - // WebSocket frame must be consistent (the total length of |data| fields must - // match |header->payload_length|). |frame_chunks| must be valid until the - // operation completes or Close() is called. + // Writes WebSocket frame data. |frame_chunks| must only contain complete + // frames. Every chunk must have a non-NULL |header| and the |final_chunk| + // boolean set to true. // - // This function should not be called while previous call of WriteFrames() is - // still pending. + // The |frame_chunks| pointer must remain valid until the operation completes + // or Close() is called. WriteFrames() will modify the contents of + // |frame_chunks| in the process of sending the message. After WriteFrames() + // has completed it is safe to clear and then re-use the vector, but other + // than that the caller should make no assumptions about its contents. // - // Support for incomplete frames is not guaranteed and may be removed from - // future iterations of the API. + // This function should not be called while a previous call to WriteFrames() + // on the same stream is pending. + // + // Frame boundaries may not be preserved. Frames may be split or + // coalesced. Message boundaries are preserved (as required by WebSocket API + // semantics). // // This method will only return OK if all frames were written completely. // Otherwise it will return an appropriate net error code. diff --git a/chromium/net/websockets/websocket_throttle_unittest.cc b/chromium/net/websockets/websocket_throttle_test.cc index fbd89caf9b7..14237b9b265 100644 --- a/chromium/net/websockets/websocket_throttle_unittest.cc +++ b/chromium/net/websockets/websocket_throttle_test.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Copyright 2013 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. |