cba.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364
  1. // Copyright 2023 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package transport
  15. import (
  16. "context"
  17. "crypto/tls"
  18. "crypto/x509"
  19. "errors"
  20. "log"
  21. "net"
  22. "net/http"
  23. "net/url"
  24. "os"
  25. "strconv"
  26. "strings"
  27. "cloud.google.com/go/auth/internal"
  28. "cloud.google.com/go/auth/internal/transport/cert"
  29. "github.com/google/s2a-go"
  30. "github.com/google/s2a-go/fallback"
  31. "google.golang.org/grpc/credentials"
  32. )
  33. const (
  34. mTLSModeAlways = "always"
  35. mTLSModeNever = "never"
  36. mTLSModeAuto = "auto"
  37. // Experimental: if true, the code will try MTLS with S2A as the default for transport security. Default value is false.
  38. googleAPIUseS2AEnv = "EXPERIMENTAL_GOOGLE_API_USE_S2A"
  39. googleAPIUseCertSource = "GOOGLE_API_USE_CLIENT_CERTIFICATE"
  40. googleAPIUseMTLS = "GOOGLE_API_USE_MTLS_ENDPOINT"
  41. googleAPIUseMTLSOld = "GOOGLE_API_USE_MTLS"
  42. universeDomainPlaceholder = "UNIVERSE_DOMAIN"
  43. mtlsMDSRoot = "/run/google-mds-mtls/root.crt"
  44. mtlsMDSKey = "/run/google-mds-mtls/client.key"
  45. )
  46. var (
  47. errUniverseNotSupportedMTLS = errors.New("mTLS is not supported in any universe other than googleapis.com")
  48. )
  49. // Options is a struct that is duplicated information from the individual
  50. // transport packages in order to avoid cyclic deps. It correlates 1:1 with
  51. // fields on httptransport.Options and grpctransport.Options.
  52. type Options struct {
  53. Endpoint string
  54. DefaultMTLSEndpoint string
  55. DefaultEndpointTemplate string
  56. ClientCertProvider cert.Provider
  57. Client *http.Client
  58. UniverseDomain string
  59. EnableDirectPath bool
  60. EnableDirectPathXds bool
  61. }
  62. // getUniverseDomain returns the default service domain for a given Cloud
  63. // universe.
  64. func (o *Options) getUniverseDomain() string {
  65. if o.UniverseDomain == "" {
  66. return internal.DefaultUniverseDomain
  67. }
  68. return o.UniverseDomain
  69. }
  70. // isUniverseDomainGDU returns true if the universe domain is the default Google
  71. // universe.
  72. func (o *Options) isUniverseDomainGDU() bool {
  73. return o.getUniverseDomain() == internal.DefaultUniverseDomain
  74. }
  75. // defaultEndpoint returns the DefaultEndpointTemplate merged with the
  76. // universe domain if the DefaultEndpointTemplate is set, otherwise returns an
  77. // empty string.
  78. func (o *Options) defaultEndpoint() string {
  79. if o.DefaultEndpointTemplate == "" {
  80. return ""
  81. }
  82. return strings.Replace(o.DefaultEndpointTemplate, universeDomainPlaceholder, o.getUniverseDomain(), 1)
  83. }
  84. // mergedEndpoint merges a user-provided Endpoint of format host[:port] with the
  85. // default endpoint.
  86. func (o *Options) mergedEndpoint() (string, error) {
  87. defaultEndpoint := o.defaultEndpoint()
  88. u, err := url.Parse(fixScheme(defaultEndpoint))
  89. if err != nil {
  90. return "", err
  91. }
  92. return strings.Replace(defaultEndpoint, u.Host, o.Endpoint, 1), nil
  93. }
  94. func fixScheme(baseURL string) string {
  95. if !strings.Contains(baseURL, "://") {
  96. baseURL = "https://" + baseURL
  97. }
  98. return baseURL
  99. }
  100. // GetGRPCTransportCredsAndEndpoint returns an instance of
  101. // [google.golang.org/grpc/credentials.TransportCredentials], and the
  102. // corresponding endpoint to use for GRPC client.
  103. func GetGRPCTransportCredsAndEndpoint(opts *Options) (credentials.TransportCredentials, string, error) {
  104. config, err := getTransportConfig(opts)
  105. if err != nil {
  106. return nil, "", err
  107. }
  108. defaultTransportCreds := credentials.NewTLS(&tls.Config{
  109. GetClientCertificate: config.clientCertSource,
  110. })
  111. var s2aAddr string
  112. var transportCredsForS2A credentials.TransportCredentials
  113. if config.mtlsS2AAddress != "" {
  114. s2aAddr = config.mtlsS2AAddress
  115. transportCredsForS2A, err = loadMTLSMDSTransportCreds(mtlsMDSRoot, mtlsMDSKey)
  116. if err != nil {
  117. log.Printf("Loading MTLS MDS credentials failed: %v", err)
  118. if config.s2aAddress != "" {
  119. s2aAddr = config.s2aAddress
  120. } else {
  121. return defaultTransportCreds, config.endpoint, nil
  122. }
  123. }
  124. } else if config.s2aAddress != "" {
  125. s2aAddr = config.s2aAddress
  126. } else {
  127. return defaultTransportCreds, config.endpoint, nil
  128. }
  129. var fallbackOpts *s2a.FallbackOptions
  130. // In case of S2A failure, fall back to the endpoint that would've been used without S2A.
  131. if fallbackHandshake, err := fallback.DefaultFallbackClientHandshakeFunc(config.endpoint); err == nil {
  132. fallbackOpts = &s2a.FallbackOptions{
  133. FallbackClientHandshakeFunc: fallbackHandshake,
  134. }
  135. }
  136. s2aTransportCreds, err := s2a.NewClientCreds(&s2a.ClientOptions{
  137. S2AAddress: s2aAddr,
  138. TransportCreds: transportCredsForS2A,
  139. FallbackOpts: fallbackOpts,
  140. })
  141. if err != nil {
  142. // Use default if we cannot initialize S2A client transport credentials.
  143. return defaultTransportCreds, config.endpoint, nil
  144. }
  145. return s2aTransportCreds, config.s2aMTLSEndpoint, nil
  146. }
  147. // GetHTTPTransportConfig returns a client certificate source and a function for
  148. // dialing MTLS with S2A.
  149. func GetHTTPTransportConfig(opts *Options) (cert.Provider, func(context.Context, string, string) (net.Conn, error), error) {
  150. config, err := getTransportConfig(opts)
  151. if err != nil {
  152. return nil, nil, err
  153. }
  154. var s2aAddr string
  155. var transportCredsForS2A credentials.TransportCredentials
  156. if config.mtlsS2AAddress != "" {
  157. s2aAddr = config.mtlsS2AAddress
  158. transportCredsForS2A, err = loadMTLSMDSTransportCreds(mtlsMDSRoot, mtlsMDSKey)
  159. if err != nil {
  160. log.Printf("Loading MTLS MDS credentials failed: %v", err)
  161. if config.s2aAddress != "" {
  162. s2aAddr = config.s2aAddress
  163. } else {
  164. return config.clientCertSource, nil, nil
  165. }
  166. }
  167. } else if config.s2aAddress != "" {
  168. s2aAddr = config.s2aAddress
  169. } else {
  170. return config.clientCertSource, nil, nil
  171. }
  172. var fallbackOpts *s2a.FallbackOptions
  173. // In case of S2A failure, fall back to the endpoint that would've been used without S2A.
  174. if fallbackURL, err := url.Parse(config.endpoint); err == nil {
  175. if fallbackDialer, fallbackServerAddr, err := fallback.DefaultFallbackDialerAndAddress(fallbackURL.Hostname()); err == nil {
  176. fallbackOpts = &s2a.FallbackOptions{
  177. FallbackDialer: &s2a.FallbackDialer{
  178. Dialer: fallbackDialer,
  179. ServerAddr: fallbackServerAddr,
  180. },
  181. }
  182. }
  183. }
  184. dialTLSContextFunc := s2a.NewS2ADialTLSContextFunc(&s2a.ClientOptions{
  185. S2AAddress: s2aAddr,
  186. TransportCreds: transportCredsForS2A,
  187. FallbackOpts: fallbackOpts,
  188. })
  189. return nil, dialTLSContextFunc, nil
  190. }
  191. func loadMTLSMDSTransportCreds(mtlsMDSRootFile, mtlsMDSKeyFile string) (credentials.TransportCredentials, error) {
  192. rootPEM, err := os.ReadFile(mtlsMDSRootFile)
  193. if err != nil {
  194. return nil, err
  195. }
  196. caCertPool := x509.NewCertPool()
  197. ok := caCertPool.AppendCertsFromPEM(rootPEM)
  198. if !ok {
  199. return nil, errors.New("failed to load MTLS MDS root certificate")
  200. }
  201. // The mTLS MDS credentials are formatted as the concatenation of a PEM-encoded certificate chain
  202. // followed by a PEM-encoded private key. For this reason, the concatenation is passed in to the
  203. // tls.X509KeyPair function as both the certificate chain and private key arguments.
  204. cert, err := tls.LoadX509KeyPair(mtlsMDSKeyFile, mtlsMDSKeyFile)
  205. if err != nil {
  206. return nil, err
  207. }
  208. tlsConfig := tls.Config{
  209. RootCAs: caCertPool,
  210. Certificates: []tls.Certificate{cert},
  211. MinVersion: tls.VersionTLS13,
  212. }
  213. return credentials.NewTLS(&tlsConfig), nil
  214. }
  215. func getTransportConfig(opts *Options) (*transportConfig, error) {
  216. clientCertSource, err := GetClientCertificateProvider(opts)
  217. if err != nil {
  218. return nil, err
  219. }
  220. endpoint, err := getEndpoint(opts, clientCertSource)
  221. if err != nil {
  222. return nil, err
  223. }
  224. defaultTransportConfig := transportConfig{
  225. clientCertSource: clientCertSource,
  226. endpoint: endpoint,
  227. }
  228. if !shouldUseS2A(clientCertSource, opts) {
  229. return &defaultTransportConfig, nil
  230. }
  231. if !opts.isUniverseDomainGDU() {
  232. return nil, errUniverseNotSupportedMTLS
  233. }
  234. s2aAddress := GetS2AAddress()
  235. mtlsS2AAddress := GetMTLSS2AAddress()
  236. if s2aAddress == "" && mtlsS2AAddress == "" {
  237. return &defaultTransportConfig, nil
  238. }
  239. return &transportConfig{
  240. clientCertSource: clientCertSource,
  241. endpoint: endpoint,
  242. s2aAddress: s2aAddress,
  243. mtlsS2AAddress: mtlsS2AAddress,
  244. s2aMTLSEndpoint: opts.DefaultMTLSEndpoint,
  245. }, nil
  246. }
  247. // GetClientCertificateProvider returns a default client certificate source, if
  248. // not provided by the user.
  249. //
  250. // A nil default source can be returned if the source does not exist. Any exceptions
  251. // encountered while initializing the default source will be reported as client
  252. // error (ex. corrupt metadata file).
  253. func GetClientCertificateProvider(opts *Options) (cert.Provider, error) {
  254. if !isClientCertificateEnabled(opts) {
  255. return nil, nil
  256. } else if opts.ClientCertProvider != nil {
  257. return opts.ClientCertProvider, nil
  258. }
  259. return cert.DefaultProvider()
  260. }
  261. // isClientCertificateEnabled returns true by default for all GDU universe domain, unless explicitly overridden by env var
  262. func isClientCertificateEnabled(opts *Options) bool {
  263. if value, ok := os.LookupEnv(googleAPIUseCertSource); ok {
  264. // error as false is OK
  265. b, _ := strconv.ParseBool(value)
  266. return b
  267. }
  268. return opts.isUniverseDomainGDU()
  269. }
  270. type transportConfig struct {
  271. // The client certificate source.
  272. clientCertSource cert.Provider
  273. // The corresponding endpoint to use based on client certificate source.
  274. endpoint string
  275. // The plaintext S2A address if it can be used, otherwise an empty string.
  276. s2aAddress string
  277. // The MTLS S2A address if it can be used, otherwise an empty string.
  278. mtlsS2AAddress string
  279. // The MTLS endpoint to use with S2A.
  280. s2aMTLSEndpoint string
  281. }
  282. // getEndpoint returns the endpoint for the service, taking into account the
  283. // user-provided endpoint override "settings.Endpoint".
  284. //
  285. // If no endpoint override is specified, we will either return the default endpoint or
  286. // the default mTLS endpoint if a client certificate is available.
  287. //
  288. // You can override the default endpoint choice (mtls vs. regular) by setting the
  289. // GOOGLE_API_USE_MTLS_ENDPOINT environment variable.
  290. //
  291. // If the endpoint override is an address (host:port) rather than full base
  292. // URL (ex. https://...), then the user-provided address will be merged into
  293. // the default endpoint. For example, WithEndpoint("myhost:8000") and
  294. // DefaultEndpointTemplate("https://UNIVERSE_DOMAIN/bar/baz") will return "https://myhost:8080/bar/baz"
  295. func getEndpoint(opts *Options, clientCertSource cert.Provider) (string, error) {
  296. if opts.Endpoint == "" {
  297. mtlsMode := getMTLSMode()
  298. if mtlsMode == mTLSModeAlways || (clientCertSource != nil && mtlsMode == mTLSModeAuto) {
  299. if !opts.isUniverseDomainGDU() {
  300. return "", errUniverseNotSupportedMTLS
  301. }
  302. return opts.DefaultMTLSEndpoint, nil
  303. }
  304. return opts.defaultEndpoint(), nil
  305. }
  306. if strings.Contains(opts.Endpoint, "://") {
  307. // User passed in a full URL path, use it verbatim.
  308. return opts.Endpoint, nil
  309. }
  310. if opts.defaultEndpoint() == "" {
  311. // If DefaultEndpointTemplate is not configured,
  312. // use the user provided endpoint verbatim. This allows a naked
  313. // "host[:port]" URL to be used with GRPC Direct Path.
  314. return opts.Endpoint, nil
  315. }
  316. // Assume user-provided endpoint is host[:port], merge it with the default endpoint.
  317. return opts.mergedEndpoint()
  318. }
  319. func getMTLSMode() string {
  320. mode := os.Getenv(googleAPIUseMTLS)
  321. if mode == "" {
  322. mode = os.Getenv(googleAPIUseMTLSOld) // Deprecated.
  323. }
  324. if mode == "" {
  325. return mTLSModeAuto
  326. }
  327. return strings.ToLower(mode)
  328. }