diff --git a/h2.go b/h2.go index 2f6342c7..7f559459 100644 --- a/h2.go +++ b/h2.go @@ -33,15 +33,28 @@ func (r *H2Transport) RoundTrip(_ *http.Request) (*http.Response, error) { if !strings.Contains(raddr, ":") { raddr += ":443" } + serverName := r.Host + if host, _, err := net.SplitHostPort(raddr); err == nil { + serverName = host + } rawServerTLS, err := dial("tcp", raddr) if err != nil { return nil, err } defer rawServerTLS.Close() + cfg := r.TLSConfig + if cfg != nil { + cfg = cfg.Clone() + } else { + cfg = &tls.Config{} + } // Ensure that we only advertise HTTP/2 as the accepted protocol. - r.TLSConfig.NextProtos = []string{http2.NextProtoTLS} + cfg.NextProtos = []string{http2.NextProtoTLS} + if cfg.ServerName == "" { + cfg.ServerName = serverName + } // Initiate TLS and check remote host name against certificate. - rawServerTLS = tls.Client(rawServerTLS, r.TLSConfig) + rawServerTLS = tls.Client(rawServerTLS, cfg) rawTLSConn, ok := rawServerTLS.(*tls.Conn) if !ok { return nil, errors.New("invalid TLS connection") @@ -49,8 +62,8 @@ func (r *H2Transport) RoundTrip(_ *http.Request) (*http.Response, error) { if err = rawTLSConn.HandshakeContext(context.Background()); err != nil { return nil, err } - if r.TLSConfig == nil || !r.TLSConfig.InsecureSkipVerify { - if err = rawTLSConn.VerifyHostname(raddr[:strings.LastIndex(raddr, ":")]); err != nil { + if !cfg.InsecureSkipVerify { + if err = rawTLSConn.VerifyHostname(serverName); err != nil { return nil, err } }