transport.go 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248
  1. // Copyright 2023 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package httptransport
  15. import (
  16. "context"
  17. "crypto/tls"
  18. "net"
  19. "net/http"
  20. "os"
  21. "time"
  22. "cloud.google.com/go/auth"
  23. "cloud.google.com/go/auth/credentials"
  24. "cloud.google.com/go/auth/internal"
  25. "cloud.google.com/go/auth/internal/transport"
  26. "cloud.google.com/go/auth/internal/transport/cert"
  27. "go.opencensus.io/plugin/ochttp"
  28. "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
  29. "golang.org/x/net/http2"
  30. )
  // quotaProjectHeaderKey is the request header used to send the quota
  // project ID on outgoing requests.
  const quotaProjectHeaderKey = "X-goog-user-project"
  34. func newTransport(base http.RoundTripper, opts *Options) (http.RoundTripper, error) {
  35. var headers = opts.Headers
  36. ht := &headerTransport{
  37. base: base,
  38. headers: headers,
  39. }
  40. var trans http.RoundTripper = ht
  41. // Give OpenTelemetry precedence over OpenCensus in case user configuration
  42. // causes both to write the same header (`X-Cloud-Trace-Context`).
  43. trans = addOpenTelemetryTransport(trans, opts)
  44. trans = addOCTransport(trans, opts)
  45. switch {
  46. case opts.DisableAuthentication:
  47. // Do nothing.
  48. case opts.APIKey != "":
  49. qp := internal.GetQuotaProject(nil, opts.Headers.Get(quotaProjectHeaderKey))
  50. if qp != "" {
  51. if headers == nil {
  52. headers = make(map[string][]string, 1)
  53. }
  54. headers.Set(quotaProjectHeaderKey, qp)
  55. }
  56. trans = &apiKeyTransport{
  57. Transport: trans,
  58. Key: opts.APIKey,
  59. }
  60. default:
  61. var creds *auth.Credentials
  62. if opts.Credentials != nil {
  63. creds = opts.Credentials
  64. } else {
  65. var err error
  66. creds, err = credentials.DetectDefault(opts.resolveDetectOptions())
  67. if err != nil {
  68. return nil, err
  69. }
  70. }
  71. qp, err := creds.QuotaProjectID(context.Background())
  72. if err != nil {
  73. return nil, err
  74. }
  75. if qp != "" {
  76. if headers == nil {
  77. headers = make(map[string][]string, 1)
  78. }
  79. // Don't overwrite user specified quota
  80. if v := headers.Get(quotaProjectHeaderKey); v == "" {
  81. headers.Set(quotaProjectHeaderKey, qp)
  82. }
  83. }
  84. var skipUD bool
  85. if iOpts := opts.InternalOptions; iOpts != nil {
  86. skipUD = iOpts.SkipUniverseDomainValidation
  87. }
  88. creds.TokenProvider = auth.NewCachedTokenProvider(creds.TokenProvider, nil)
  89. trans = &authTransport{
  90. base: trans,
  91. creds: creds,
  92. clientUniverseDomain: opts.UniverseDomain,
  93. skipUniverseDomainValidation: skipUD,
  94. }
  95. }
  96. return trans, nil
  97. }
  98. // defaultBaseTransport returns the base HTTP transport.
  99. // On App Engine, this is urlfetch.Transport.
  100. // Otherwise, use a default transport, taking most defaults from
  101. // http.DefaultTransport.
  102. // If TLSCertificate is available, set TLSClientConfig as well.
  103. func defaultBaseTransport(clientCertSource cert.Provider, dialTLSContext func(context.Context, string, string) (net.Conn, error)) http.RoundTripper {
  104. defaultTransport, ok := http.DefaultTransport.(*http.Transport)
  105. if !ok {
  106. defaultTransport = transport.BaseTransport()
  107. }
  108. trans := defaultTransport.Clone()
  109. trans.MaxIdleConnsPerHost = 100
  110. if clientCertSource != nil {
  111. trans.TLSClientConfig = &tls.Config{
  112. GetClientCertificate: clientCertSource,
  113. }
  114. }
  115. if dialTLSContext != nil {
  116. // If DialTLSContext is set, TLSClientConfig wil be ignored
  117. trans.DialTLSContext = dialTLSContext
  118. }
  119. // Configures the ReadIdleTimeout HTTP/2 option for the
  120. // transport. This allows broken idle connections to be pruned more quickly,
  121. // preventing the client from attempting to re-use connections that will no
  122. // longer work.
  123. http2Trans, err := http2.ConfigureTransports(trans)
  124. if err == nil {
  125. http2Trans.ReadIdleTimeout = time.Second * 31
  126. }
  127. return trans
  128. }
  // apiKeyTransport is an http.RoundTripper that authenticates requests by
  // adding an API key as the "key" URL query parameter.
  type apiKeyTransport struct {
  	// Key is the API Key to set on requests.
  	Key string
  	// Transport is the underlying HTTP transport.
  	// If nil, http.DefaultTransport is used.
  	Transport http.RoundTripper
  }

  // RoundTrip sends a copy of req whose URL carries t.Key in the "key" query
  // parameter, replacing any existing value. The caller's request, including
  // its URL, is not modified, as the http.RoundTripper contract requires.
  func (t *apiKeyTransport) RoundTrip(req *http.Request) (*http.Response, error) {
  	rt := t.Transport
  	if rt == nil {
  		rt = http.DefaultTransport
  	}
  	// A shallow copy of *req would share req.URL, so copy the URL before
  	// rewriting its query string.
  	u := *req.URL
  	args := u.Query()
  	args.Set("key", t.Key)
  	u.RawQuery = args.Encode()

  	newReq := *req
  	newReq.URL = &u
  	return rt.RoundTrip(&newReq)
  }
  // headerTransport is an http.RoundTripper that adds a fixed set of headers
  // to every request before handing it to base.
  type headerTransport struct {
  	// headers are applied to each outgoing request; an entry here replaces
  	// the request's own entry under the same key.
  	headers http.Header
  	// base receives the request carrying the merged headers.
  	base http.RoundTripper
  }

  // RoundTrip forwards a shallow copy of req whose Header is a fresh map
  // holding req's headers overlaid with t.headers. req's own Header map is
  // not modified; the value slices are shared rather than copied.
  func (t *headerTransport) RoundTrip(req *http.Request) (*http.Response, error) {
  	merged := make(http.Header, len(req.Header)+len(t.headers))
  	// Later sources win, so t.headers overrides the request's values.
  	for _, src := range []http.Header{req.Header, t.headers} {
  		for key, values := range src {
  			merged[key] = values
  		}
  	}
  	outgoing := *req
  	outgoing.Header = merged
  	return t.base.RoundTrip(&outgoing)
  }
  159. func addOpenTelemetryTransport(trans http.RoundTripper, opts *Options) http.RoundTripper {
  160. if opts.DisableTelemetry {
  161. return trans
  162. }
  163. return otelhttp.NewTransport(trans)
  164. }
  165. func addOCTransport(trans http.RoundTripper, opts *Options) http.RoundTripper {
  166. if opts.DisableTelemetry {
  167. return trans
  168. }
  169. return &ochttp.Transport{
  170. Base: trans,
  171. Propagation: &httpFormat{},
  172. }
  173. }
  174. type authTransport struct {
  175. creds *auth.Credentials
  176. base http.RoundTripper
  177. clientUniverseDomain string
  178. skipUniverseDomainValidation bool
  179. }
  180. // getClientUniverseDomain returns the default service domain for a given Cloud
  181. // universe, with the following precedence:
  182. //
  183. // 1. A non-empty option.WithUniverseDomain or similar client option.
  184. // 2. A non-empty environment variable GOOGLE_CLOUD_UNIVERSE_DOMAIN.
  185. // 3. The default value "googleapis.com".
  186. //
  187. // This is the universe domain configured for the client, which will be compared
  188. // to the universe domain that is separately configured for the credentials.
  189. func (t *authTransport) getClientUniverseDomain() string {
  190. if t.clientUniverseDomain != "" {
  191. return t.clientUniverseDomain
  192. }
  193. if envUD := os.Getenv(internal.UniverseDomainEnvVar); envUD != "" {
  194. return envUD
  195. }
  196. return internal.DefaultUniverseDomain
  197. }
  198. // RoundTrip authorizes and authenticates the request with an
  199. // access token from Transport's Source. Per the RoundTripper contract we must
  200. // not modify the initial request, so we clone it, and we must close the body
  201. // on any errors that happens during our token logic.
  202. func (t *authTransport) RoundTrip(req *http.Request) (*http.Response, error) {
  203. reqBodyClosed := false
  204. if req.Body != nil {
  205. defer func() {
  206. if !reqBodyClosed {
  207. req.Body.Close()
  208. }
  209. }()
  210. }
  211. token, err := t.creds.Token(req.Context())
  212. if err != nil {
  213. return nil, err
  214. }
  215. if !t.skipUniverseDomainValidation && token.MetadataString("auth.google.tokenSource") != "compute-metadata" {
  216. credentialsUniverseDomain, err := t.creds.UniverseDomain(req.Context())
  217. if err != nil {
  218. return nil, err
  219. }
  220. if err := transport.ValidateUniverseDomain(t.getClientUniverseDomain(), credentialsUniverseDomain); err != nil {
  221. return nil, err
  222. }
  223. }
  224. req2 := req.Clone(req.Context())
  225. SetAuthHeader(token, req2)
  226. reqBodyClosed = true
  227. return t.base.RoundTrip(req2)
  228. }