diff --git a/transport/internet/tcp/dialer.go b/transport/internet/tcp/dialer.go index 8d9bbf78..15ff4248 100644 --- a/transport/internet/tcp/dialer.go +++ b/transport/internet/tcp/dialer.go @@ -33,7 +33,9 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me if IsFromMitm(tlsConfig.ServerName) { tlsConfig.ServerName = mitmServerName } - if r, ok := tlsConfig.Rand.(*tls.RandCarrier); ok && len(r.VerifyPeerCertInNames) > 0 && IsFromMitm(r.VerifyPeerCertInNames[0]) { + r, ok := tlsConfig.Rand.(*tls.RandCarrier) + isFromMitmVerify := ok && len(r.VerifyPeerCertInNames) > 0 && IsFromMitm(r.VerifyPeerCertInNames[0]) + if isFromMitmVerify { r.VerifyPeerCertInNames = r.VerifyPeerCertInNames[1:] after := mitmServerName for { @@ -46,29 +48,34 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me } } } + isFromMitmAlpn := len(tlsConfig.NextProtos) == 1 && IsFromMitm(tlsConfig.NextProtos[0]) + if isFromMitmAlpn { + if mitmAlpn11 { + tlsConfig.NextProtos[0] = "http/1.1" + } else { + tlsConfig.NextProtos = nil + } + } if fingerprint := tls.GetFingerprint(config.Fingerprint); fingerprint != nil { conn = tls.UClient(conn, tlsConfig, fingerprint) - if len(tlsConfig.NextProtos) == 1 && (tlsConfig.NextProtos[0] == "http/1.1" || (IsFromMitm(tlsConfig.NextProtos[0]) && mitmAlpn11)) { - if err := conn.(*tls.UConn).WebsocketHandshakeContext(ctx); err != nil { - return nil, err - } + if len(tlsConfig.NextProtos) == 1 && tlsConfig.NextProtos[0] == "http/1.1" { // allow manually specify + err = conn.(*tls.UConn).WebsocketHandshakeContext(ctx) } else { - if err := conn.(*tls.UConn).HandshakeContext(ctx); err != nil { - return nil, err - } + err = conn.(*tls.UConn).HandshakeContext(ctx) } } else { - if len(tlsConfig.NextProtos) == 1 && IsFromMitm(tlsConfig.NextProtos[0]) { - if mitmAlpn11 { - tlsConfig.NextProtos[0] = "http/1.1" - } else { - tlsConfig.NextProtos = nil - } - } conn = tls.Client(conn, tlsConfig) - if err := conn.(*tls.Conn).HandshakeContext(ctx); err != nil { - return nil, err + err = conn.(*tls.Conn).HandshakeContext(ctx) + } + if err != nil { + if isFromMitmVerify { + return nil, errors.New("MITM: failed to verify " + mitmServerName).Base(err).AtWarning() } + return nil, err + } + if isFromMitmAlpn && !mitmAlpn11 && conn.(tls.Interface).NegotiatedProtocol() == "http/1.1" { + conn.Close() + return nil, errors.New("MITM: received unexpected ALPN http/1.1 from " + mitmServerName).AtWarning() } } else if config := reality.ConfigFromStreamSettings(streamSettings); config != nil { if conn, err = reality.UClient(conn, config, ctx, dest); err != nil {