| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364 |
- // Copyright 2023 Google LLC
- //
- // Licensed under the Apache License, Version 2.0 (the "License");
- // you may not use this file except in compliance with the License.
- // You may obtain a copy of the License at
- //
- // http://www.apache.org/licenses/LICENSE-2.0
- //
- // Unless required by applicable law or agreed to in writing, software
- // distributed under the License is distributed on an "AS IS" BASIS,
- // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- // See the License for the specific language governing permissions and
- // limitations under the License.
- package transport
- import (
- "context"
- "crypto/tls"
- "crypto/x509"
- "errors"
- "log"
- "net"
- "net/http"
- "net/url"
- "os"
- "strconv"
- "strings"
- "cloud.google.com/go/auth/internal"
- "cloud.google.com/go/auth/internal/transport/cert"
- "github.com/google/s2a-go"
- "github.com/google/s2a-go/fallback"
- "google.golang.org/grpc/credentials"
- )
// mTLS endpoint modes. getMTLSMode yields one of these (or an arbitrary
// lower-cased value) and getEndpoint compares against them.
const (
	mTLSModeAlways = "always"
	mTLSModeNever  = "never"
	mTLSModeAuto   = "auto"
)

// Names of environment variables consulted by this package.
const (
	// googleAPIUseS2AEnv is experimental: if true, the code will try MTLS with
	// S2A as the default for transport security. Default value is false.
	googleAPIUseS2AEnv = "EXPERIMENTAL_GOOGLE_API_USE_S2A"
	// googleAPIUseCertSource toggles client certificates; see
	// isClientCertificateEnabled.
	googleAPIUseCertSource = "GOOGLE_API_USE_CLIENT_CERTIFICATE"
	// googleAPIUseMTLS selects the mTLS endpoint mode; see getMTLSMode.
	googleAPIUseMTLS = "GOOGLE_API_USE_MTLS_ENDPOINT"
	// googleAPIUseMTLSOld is the deprecated name consulted when
	// googleAPIUseMTLS is unset or empty.
	googleAPIUseMTLSOld = "GOOGLE_API_USE_MTLS"
)

const (
	// universeDomainPlaceholder is the token in
	// Options.DefaultEndpointTemplate that is replaced by the universe domain.
	universeDomainPlaceholder = "UNIVERSE_DOMAIN"

	// Files read by loadMTLSMDSTransportCreds: the CA root bundle and the
	// concatenated client certificate chain + private key.
	mtlsMDSRoot = "/run/google-mds-mtls/root.crt"
	mtlsMDSKey  = "/run/google-mds-mtls/client.key"
)

// errUniverseNotSupportedMTLS is returned when mTLS (directly or via S2A)
// would be used outside the default universe.
var errUniverseNotSupportedMTLS = errors.New("mTLS is not supported in any universe other than googleapis.com")
- // Options is a struct that is duplicated information from the individual
- // transport packages in order to avoid cyclic deps. It correlates 1:1 with
- // fields on httptransport.Options and grpctransport.Options.
- type Options struct {
- Endpoint string
- DefaultMTLSEndpoint string
- DefaultEndpointTemplate string
- ClientCertProvider cert.Provider
- Client *http.Client
- UniverseDomain string
- EnableDirectPath bool
- EnableDirectPathXds bool
- }
- // getUniverseDomain returns the default service domain for a given Cloud
- // universe.
- func (o *Options) getUniverseDomain() string {
- if o.UniverseDomain == "" {
- return internal.DefaultUniverseDomain
- }
- return o.UniverseDomain
- }
- // isUniverseDomainGDU returns true if the universe domain is the default Google
- // universe.
- func (o *Options) isUniverseDomainGDU() bool {
- return o.getUniverseDomain() == internal.DefaultUniverseDomain
- }
- // defaultEndpoint returns the DefaultEndpointTemplate merged with the
- // universe domain if the DefaultEndpointTemplate is set, otherwise returns an
- // empty string.
- func (o *Options) defaultEndpoint() string {
- if o.DefaultEndpointTemplate == "" {
- return ""
- }
- return strings.Replace(o.DefaultEndpointTemplate, universeDomainPlaceholder, o.getUniverseDomain(), 1)
- }
- // mergedEndpoint merges a user-provided Endpoint of format host[:port] with the
- // default endpoint.
- func (o *Options) mergedEndpoint() (string, error) {
- defaultEndpoint := o.defaultEndpoint()
- u, err := url.Parse(fixScheme(defaultEndpoint))
- if err != nil {
- return "", err
- }
- return strings.Replace(defaultEndpoint, u.Host, o.Endpoint, 1), nil
- }
// fixScheme returns baseURL unchanged if it already contains a "://" scheme
// separator; otherwise it prefixes "https://" so the value can be parsed as a
// URL with a host.
func fixScheme(baseURL string) string {
	if strings.Contains(baseURL, "://") {
		return baseURL
	}
	return "https://" + baseURL
}
- // GetGRPCTransportCredsAndEndpoint returns an instance of
- // [google.golang.org/grpc/credentials.TransportCredentials], and the
- // corresponding endpoint to use for GRPC client.
- func GetGRPCTransportCredsAndEndpoint(opts *Options) (credentials.TransportCredentials, string, error) {
- config, err := getTransportConfig(opts)
- if err != nil {
- return nil, "", err
- }
- defaultTransportCreds := credentials.NewTLS(&tls.Config{
- GetClientCertificate: config.clientCertSource,
- })
- var s2aAddr string
- var transportCredsForS2A credentials.TransportCredentials
- if config.mtlsS2AAddress != "" {
- s2aAddr = config.mtlsS2AAddress
- transportCredsForS2A, err = loadMTLSMDSTransportCreds(mtlsMDSRoot, mtlsMDSKey)
- if err != nil {
- log.Printf("Loading MTLS MDS credentials failed: %v", err)
- if config.s2aAddress != "" {
- s2aAddr = config.s2aAddress
- } else {
- return defaultTransportCreds, config.endpoint, nil
- }
- }
- } else if config.s2aAddress != "" {
- s2aAddr = config.s2aAddress
- } else {
- return defaultTransportCreds, config.endpoint, nil
- }
- var fallbackOpts *s2a.FallbackOptions
- // In case of S2A failure, fall back to the endpoint that would've been used without S2A.
- if fallbackHandshake, err := fallback.DefaultFallbackClientHandshakeFunc(config.endpoint); err == nil {
- fallbackOpts = &s2a.FallbackOptions{
- FallbackClientHandshakeFunc: fallbackHandshake,
- }
- }
- s2aTransportCreds, err := s2a.NewClientCreds(&s2a.ClientOptions{
- S2AAddress: s2aAddr,
- TransportCreds: transportCredsForS2A,
- FallbackOpts: fallbackOpts,
- })
- if err != nil {
- // Use default if we cannot initialize S2A client transport credentials.
- return defaultTransportCreds, config.endpoint, nil
- }
- return s2aTransportCreds, config.s2aMTLSEndpoint, nil
- }
- // GetHTTPTransportConfig returns a client certificate source and a function for
- // dialing MTLS with S2A.
- func GetHTTPTransportConfig(opts *Options) (cert.Provider, func(context.Context, string, string) (net.Conn, error), error) {
- config, err := getTransportConfig(opts)
- if err != nil {
- return nil, nil, err
- }
- var s2aAddr string
- var transportCredsForS2A credentials.TransportCredentials
- if config.mtlsS2AAddress != "" {
- s2aAddr = config.mtlsS2AAddress
- transportCredsForS2A, err = loadMTLSMDSTransportCreds(mtlsMDSRoot, mtlsMDSKey)
- if err != nil {
- log.Printf("Loading MTLS MDS credentials failed: %v", err)
- if config.s2aAddress != "" {
- s2aAddr = config.s2aAddress
- } else {
- return config.clientCertSource, nil, nil
- }
- }
- } else if config.s2aAddress != "" {
- s2aAddr = config.s2aAddress
- } else {
- return config.clientCertSource, nil, nil
- }
- var fallbackOpts *s2a.FallbackOptions
- // In case of S2A failure, fall back to the endpoint that would've been used without S2A.
- if fallbackURL, err := url.Parse(config.endpoint); err == nil {
- if fallbackDialer, fallbackServerAddr, err := fallback.DefaultFallbackDialerAndAddress(fallbackURL.Hostname()); err == nil {
- fallbackOpts = &s2a.FallbackOptions{
- FallbackDialer: &s2a.FallbackDialer{
- Dialer: fallbackDialer,
- ServerAddr: fallbackServerAddr,
- },
- }
- }
- }
- dialTLSContextFunc := s2a.NewS2ADialTLSContextFunc(&s2a.ClientOptions{
- S2AAddress: s2aAddr,
- TransportCreds: transportCredsForS2A,
- FallbackOpts: fallbackOpts,
- })
- return nil, dialTLSContextFunc, nil
- }
- func loadMTLSMDSTransportCreds(mtlsMDSRootFile, mtlsMDSKeyFile string) (credentials.TransportCredentials, error) {
- rootPEM, err := os.ReadFile(mtlsMDSRootFile)
- if err != nil {
- return nil, err
- }
- caCertPool := x509.NewCertPool()
- ok := caCertPool.AppendCertsFromPEM(rootPEM)
- if !ok {
- return nil, errors.New("failed to load MTLS MDS root certificate")
- }
- // The mTLS MDS credentials are formatted as the concatenation of a PEM-encoded certificate chain
- // followed by a PEM-encoded private key. For this reason, the concatenation is passed in to the
- // tls.X509KeyPair function as both the certificate chain and private key arguments.
- cert, err := tls.LoadX509KeyPair(mtlsMDSKeyFile, mtlsMDSKeyFile)
- if err != nil {
- return nil, err
- }
- tlsConfig := tls.Config{
- RootCAs: caCertPool,
- Certificates: []tls.Certificate{cert},
- MinVersion: tls.VersionTLS13,
- }
- return credentials.NewTLS(&tlsConfig), nil
- }
- func getTransportConfig(opts *Options) (*transportConfig, error) {
- clientCertSource, err := GetClientCertificateProvider(opts)
- if err != nil {
- return nil, err
- }
- endpoint, err := getEndpoint(opts, clientCertSource)
- if err != nil {
- return nil, err
- }
- defaultTransportConfig := transportConfig{
- clientCertSource: clientCertSource,
- endpoint: endpoint,
- }
- if !shouldUseS2A(clientCertSource, opts) {
- return &defaultTransportConfig, nil
- }
- if !opts.isUniverseDomainGDU() {
- return nil, errUniverseNotSupportedMTLS
- }
- s2aAddress := GetS2AAddress()
- mtlsS2AAddress := GetMTLSS2AAddress()
- if s2aAddress == "" && mtlsS2AAddress == "" {
- return &defaultTransportConfig, nil
- }
- return &transportConfig{
- clientCertSource: clientCertSource,
- endpoint: endpoint,
- s2aAddress: s2aAddress,
- mtlsS2AAddress: mtlsS2AAddress,
- s2aMTLSEndpoint: opts.DefaultMTLSEndpoint,
- }, nil
- }
- // GetClientCertificateProvider returns a default client certificate source, if
- // not provided by the user.
- //
- // A nil default source can be returned if the source does not exist. Any exceptions
- // encountered while initializing the default source will be reported as client
- // error (ex. corrupt metadata file).
- func GetClientCertificateProvider(opts *Options) (cert.Provider, error) {
- if !isClientCertificateEnabled(opts) {
- return nil, nil
- } else if opts.ClientCertProvider != nil {
- return opts.ClientCertProvider, nil
- }
- return cert.DefaultProvider()
- }
- // isClientCertificateEnabled returns true by default for all GDU universe domain, unless explicitly overridden by env var
- func isClientCertificateEnabled(opts *Options) bool {
- if value, ok := os.LookupEnv(googleAPIUseCertSource); ok {
- // error as false is OK
- b, _ := strconv.ParseBool(value)
- return b
- }
- return opts.isUniverseDomainGDU()
- }
- type transportConfig struct {
- // The client certificate source.
- clientCertSource cert.Provider
- // The corresponding endpoint to use based on client certificate source.
- endpoint string
- // The plaintext S2A address if it can be used, otherwise an empty string.
- s2aAddress string
- // The MTLS S2A address if it can be used, otherwise an empty string.
- mtlsS2AAddress string
- // The MTLS endpoint to use with S2A.
- s2aMTLSEndpoint string
- }
- // getEndpoint returns the endpoint for the service, taking into account the
- // user-provided endpoint override "settings.Endpoint".
- //
- // If no endpoint override is specified, we will either return the default endpoint or
- // the default mTLS endpoint if a client certificate is available.
- //
- // You can override the default endpoint choice (mtls vs. regular) by setting the
- // GOOGLE_API_USE_MTLS_ENDPOINT environment variable.
- //
- // If the endpoint override is an address (host:port) rather than full base
- // URL (ex. https://...), then the user-provided address will be merged into
- // the default endpoint. For example, WithEndpoint("myhost:8000") and
- // DefaultEndpointTemplate("https://UNIVERSE_DOMAIN/bar/baz") will return "https://myhost:8080/bar/baz"
- func getEndpoint(opts *Options, clientCertSource cert.Provider) (string, error) {
- if opts.Endpoint == "" {
- mtlsMode := getMTLSMode()
- if mtlsMode == mTLSModeAlways || (clientCertSource != nil && mtlsMode == mTLSModeAuto) {
- if !opts.isUniverseDomainGDU() {
- return "", errUniverseNotSupportedMTLS
- }
- return opts.DefaultMTLSEndpoint, nil
- }
- return opts.defaultEndpoint(), nil
- }
- if strings.Contains(opts.Endpoint, "://") {
- // User passed in a full URL path, use it verbatim.
- return opts.Endpoint, nil
- }
- if opts.defaultEndpoint() == "" {
- // If DefaultEndpointTemplate is not configured,
- // use the user provided endpoint verbatim. This allows a naked
- // "host[:port]" URL to be used with GRPC Direct Path.
- return opts.Endpoint, nil
- }
- // Assume user-provided endpoint is host[:port], merge it with the default endpoint.
- return opts.mergedEndpoint()
- }
- func getMTLSMode() string {
- mode := os.Getenv(googleAPIUseMTLS)
- if mode == "" {
- mode = os.Getenv(googleAPIUseMTLSOld) // Deprecated.
- }
- if mode == "" {
- return mTLSModeAuto
- }
- return strings.ToLower(mode)
- }
|