runtime.go 18 KB


  1. package templ
  2. import (
  3. "bytes"
  4. "context"
  5. "crypto/sha256"
  6. "encoding/hex"
  7. "errors"
  8. "fmt"
  9. "html"
  10. "html/template"
  11. "io"
  12. "net/http"
  13. "os"
  14. "reflect"
  15. "runtime"
  16. "sort"
  17. "strconv"
  18. "strings"
  19. "sync"
  20. "time"
  21. "github.com/a-h/templ/safehtml"
  22. )
  23. // Types exposed by all components.
  // Component is implemented by every template: a value that knows how to
  // write its own output.
  type Component interface {
  	// Render writes the component's output to w and reports any error
  	// encountered while doing so.
  	Render(ctx context.Context, w io.Writer) error
  }
  // ComponentFunc adapts a plain function that has the same signature as
  // Component.Render so that it can be used as a Component.
  type ComponentFunc func(ctx context.Context, w io.Writer) error

  // Render calls the underlying function with ctx and w and returns its error.
  func (cf ComponentFunc) Render(ctx context.Context, w io.Writer) error {
  	err := cf(ctx, w)
  	return err
  }
  36. // WithNonce sets a CSP nonce on the context and returns it.
  37. func WithNonce(ctx context.Context, nonce string) context.Context {
  38. ctx, v := getContext(ctx)
  39. v.nonce = nonce
  40. return ctx
  41. }
  42. // GetNonce returns the CSP nonce value set with WithNonce, or an
  43. // empty string if none has been set.
  44. func GetNonce(ctx context.Context) (nonce string) {
  45. if ctx == nil {
  46. return ""
  47. }
  48. _, v := getContext(ctx)
  49. return v.nonce
  50. }
  51. func WithChildren(ctx context.Context, children Component) context.Context {
  52. ctx, v := getContext(ctx)
  53. v.children = &children
  54. return ctx
  55. }
  56. func ClearChildren(ctx context.Context) context.Context {
  57. _, v := getContext(ctx)
  58. v.children = nil
  59. return ctx
  60. }
  61. // NopComponent is a component that doesn't render anything.
  62. var NopComponent = ComponentFunc(func(ctx context.Context, w io.Writer) error { return nil })
  63. // GetChildren from the context.
  64. func GetChildren(ctx context.Context) Component {
  65. _, v := getContext(ctx)
  66. if v.children == nil {
  67. return NopComponent
  68. }
  69. return *v.children
  70. }
  // EscapeString escapes s for use as HTML text by delegating to
  // html.EscapeString.
  func EscapeString(s string) string {
  	escaped := html.EscapeString(s)
  	return escaped
  }
  // Bool returns value unchanged. It exists so that boolean attribute values
  // can be written as an explicit call.
  func Bool(value bool) bool {
  	result := value
  	return result
  }
  79. // Classes for CSS.
  80. // Supported types are string, ConstantCSSClass, ComponentCSSClass, map[string]bool.
  81. func Classes(classes ...any) CSSClasses {
  82. return CSSClasses(classes)
  83. }
  84. // CSSClasses is a slice of CSS classes.
  85. type CSSClasses []any
  86. // String returns the names of all CSS classes.
  87. func (classes CSSClasses) String() string {
  88. if len(classes) == 0 {
  89. return ""
  90. }
  91. cp := newCSSProcessor()
  92. for _, v := range classes {
  93. cp.Add(v)
  94. }
  95. return cp.String()
  96. }
  // newCSSProcessor returns a cssProcessor with its lookup map initialised and
  // no names recorded.
  func newCSSProcessor() *cssProcessor {
  	cp := new(cssProcessor)
  	cp.classNameToEnabled = map[string]bool{}
  	return cp
  }

  // cssProcessor accumulates class names together with whether each one is
  // enabled.
  type cssProcessor struct {
  	// classNameToEnabled holds the most recently recorded enabled flag per name.
  	classNameToEnabled map[string]bool
  	// orderedNames holds every recorded name in the order it was added.
  	orderedNames []string
  }
  106. func (cp *cssProcessor) Add(item any) {
  107. switch c := item.(type) {
  108. case []string:
  109. for _, className := range c {
  110. cp.AddClassName(className, true)
  111. }
  112. case string:
  113. cp.AddClassName(c, true)
  114. case ConstantCSSClass:
  115. cp.AddClassName(c.ClassName(), true)
  116. case ComponentCSSClass:
  117. cp.AddClassName(c.ClassName(), true)
  118. case map[string]bool:
  119. // In Go, map keys are iterated in a randomized order.
  120. // So the keys in the map must be sorted to produce consistent output.
  121. keys := make([]string, len(c))
  122. var i int
  123. for key := range c {
  124. keys[i] = key
  125. i++
  126. }
  127. sort.Strings(keys)
  128. for _, className := range keys {
  129. cp.AddClassName(className, c[className])
  130. }
  131. case []KeyValue[string, bool]:
  132. for _, kv := range c {
  133. cp.AddClassName(kv.Key, kv.Value)
  134. }
  135. case KeyValue[string, bool]:
  136. cp.AddClassName(c.Key, c.Value)
  137. case []KeyValue[CSSClass, bool]:
  138. for _, kv := range c {
  139. cp.AddClassName(kv.Key.ClassName(), kv.Value)
  140. }
  141. case KeyValue[CSSClass, bool]:
  142. cp.AddClassName(c.Key.ClassName(), c.Value)
  143. case CSSClasses:
  144. for _, item := range c {
  145. cp.Add(item)
  146. }
  147. case []CSSClass:
  148. for _, item := range c {
  149. cp.Add(item)
  150. }
  151. case func() CSSClass:
  152. cp.AddClassName(c().ClassName(), true)
  153. default:
  154. cp.AddClassName(unknownTypeClassName, true)
  155. }
  156. }
  157. func (cp *cssProcessor) AddClassName(className string, enabled bool) {
  158. cp.classNameToEnabled[className] = enabled
  159. cp.orderedNames = append(cp.orderedNames, className)
  160. }
  161. func (cp *cssProcessor) String() string {
  162. // Order the outputs according to how they were input, and remove disabled names.
  163. rendered := make(map[string]any, len(cp.classNameToEnabled))
  164. var names []string
  165. for _, name := range cp.orderedNames {
  166. if enabled := cp.classNameToEnabled[name]; !enabled {
  167. continue
  168. }
  169. if _, hasBeenRendered := rendered[name]; hasBeenRendered {
  170. continue
  171. }
  172. names = append(names, name)
  173. rendered[name] = struct{}{}
  174. }
  175. return strings.Join(names, " ")
  176. }
  // KeyValue pairs a key with a value. When encoded as JSON the key is written
  // under "name" and the value under "value".
  type KeyValue[TKey comparable, TValue any] struct {
  	Key   TKey   `json:"name"`
  	Value TValue `json:"value"`
  }

  // KV builds a KeyValue from key and value.
  func KV[TKey comparable, TValue any](key TKey, value TValue) KeyValue[TKey, TValue] {
  	var pair KeyValue[TKey, TValue]
  	pair.Key, pair.Value = key, value
  	return pair
  }
  // unknownTypeClassName is the class name recorded when a value of an
  // unsupported type is added as a CSS class.
  const unknownTypeClassName = "--templ-css-class-unknown-type"
  190. // Class returns a CSS class name.
  191. // Deprecated: use a string instead.
  192. func Class(name string) CSSClass {
  193. return SafeClass(name)
  194. }
  195. // SafeClass bypasses CSS class name validation.
  196. // Deprecated: use a string instead.
  197. func SafeClass(name string) CSSClass {
  198. return ConstantCSSClass(name)
  199. }
  // CSSClass is implemented by values that can report a CSS class name.
  type CSSClass interface {
  	// ClassName returns the name of the class.
  	ClassName() string
  }
  // ConstantCSSClass is a CSS class name held as a string.
  //
  // Deprecated: use a string instead.
  type ConstantCSSClass string

  // ClassName returns the class name as a plain string.
  func (css ConstantCSSClass) ClassName() string {
  	name := string(css)
  	return name
  }
  211. // ComponentCSSClass is a templ.CSS
  212. type ComponentCSSClass struct {
  213. // ID of the class, will be autogenerated.
  214. ID string
  215. // Definition of the CSS.
  216. Class SafeCSS
  217. }
  218. // ClassName of the CSS class.
  219. func (css ComponentCSSClass) ClassName() string {
  220. return css.ID
  221. }
  // CSSID derives an identifier from name and css: name, an underscore, and the
  // first four hex characters of the SHA-256 digest of css.
  func CSSID(name string, css string) string {
  	digest := sha256.Sum256([]byte(css))
  	// Two digest bytes encode to exactly the four hex characters required.
  	suffix := hex.EncodeToString(digest[:2])
  	// Plain concatenation was benchmarked as the fastest option with the
  	// fewest allocations (1), against strings.Builder (2) and fmt.Sprintf (3).
  	return name + "_" + suffix
  }
  231. // NewCSSMiddleware creates HTTP middleware that renders a global stylesheet of ComponentCSSClass
  232. // CSS if the request path matches, or updates the HTTP context to ensure that any handlers that
  233. // use templ.Components skip rendering <style> elements for classes that are included in the global
  234. // stylesheet. By default, the stylesheet path is /styles/templ.css
  235. func NewCSSMiddleware(next http.Handler, classes ...CSSClass) CSSMiddleware {
  236. return CSSMiddleware{
  237. Path: "/styles/templ.css",
  238. CSSHandler: NewCSSHandler(classes...),
  239. Next: next,
  240. }
  241. }
  242. // CSSMiddleware renders a global stylesheet.
  243. type CSSMiddleware struct {
  244. Path string
  245. CSSHandler CSSHandler
  246. Next http.Handler
  247. }
  248. func (cssm CSSMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  249. if r.URL.Path == cssm.Path {
  250. cssm.CSSHandler.ServeHTTP(w, r)
  251. return
  252. }
  253. // Add registered classes to the context.
  254. ctx, v := getContext(r.Context())
  255. for _, c := range cssm.CSSHandler.Classes {
  256. v.addClass(c.ID)
  257. }
  258. // Serve the request. Templ components will use the updated context
  259. // to know to skip rendering <style> elements for any component CSS
  260. // classes that have been included in the global stylesheet.
  261. cssm.Next.ServeHTTP(w, r.WithContext(ctx))
  262. }
  263. // NewCSSHandler creates a handler that serves a stylesheet containing the CSS of the
  264. // classes passed in. This is used by the CSSMiddleware to provide global stylesheets
  265. // for templ components.
  266. func NewCSSHandler(classes ...CSSClass) CSSHandler {
  267. ccssc := make([]ComponentCSSClass, 0, len(classes))
  268. for _, c := range classes {
  269. ccss, ok := c.(ComponentCSSClass)
  270. if !ok {
  271. continue
  272. }
  273. ccssc = append(ccssc, ccss)
  274. }
  275. return CSSHandler{
  276. Classes: ccssc,
  277. }
  278. }
  279. // CSSHandler is a HTTP handler that serves CSS.
  280. type CSSHandler struct {
  281. Logger func(err error)
  282. Classes []ComponentCSSClass
  283. }
  284. func (cssh CSSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  285. w.Header().Set("Content-Type", "text/css")
  286. for _, c := range cssh.Classes {
  287. _, err := w.Write([]byte(c.Class))
  288. if err != nil && cssh.Logger != nil {
  289. cssh.Logger(err)
  290. }
  291. }
  292. }
  293. // RenderCSSItems renders the CSS to the writer, if the items haven't already been rendered.
  294. func RenderCSSItems(ctx context.Context, w io.Writer, classes ...any) (err error) {
  295. if len(classes) == 0 {
  296. return nil
  297. }
  298. _, v := getContext(ctx)
  299. sb := new(strings.Builder)
  300. renderCSSItemsToBuilder(sb, v, classes...)
  301. if sb.Len() > 0 {
  302. if _, err = io.WriteString(w, `<style type="text/css">`); err != nil {
  303. return err
  304. }
  305. if _, err = io.WriteString(w, sb.String()); err != nil {
  306. return err
  307. }
  308. if _, err = io.WriteString(w, `</style>`); err != nil {
  309. return err
  310. }
  311. }
  312. return nil
  313. }
  314. func renderCSSItemsToBuilder(sb *strings.Builder, v *contextValue, classes ...any) {
  315. for _, c := range classes {
  316. switch ccc := c.(type) {
  317. case ComponentCSSClass:
  318. if !v.hasClassBeenRendered(ccc.ID) {
  319. sb.WriteString(string(ccc.Class))
  320. v.addClass(ccc.ID)
  321. }
  322. case KeyValue[ComponentCSSClass, bool]:
  323. if !ccc.Value {
  324. continue
  325. }
  326. renderCSSItemsToBuilder(sb, v, ccc.Key)
  327. case KeyValue[CSSClass, bool]:
  328. if !ccc.Value {
  329. continue
  330. }
  331. renderCSSItemsToBuilder(sb, v, ccc.Key)
  332. case CSSClasses:
  333. renderCSSItemsToBuilder(sb, v, ccc...)
  334. case []CSSClass:
  335. for _, item := range ccc {
  336. renderCSSItemsToBuilder(sb, v, item)
  337. }
  338. case func() CSSClass:
  339. renderCSSItemsToBuilder(sb, v, ccc())
  340. case []string:
  341. // Skip. These are class names, not CSS classes.
  342. case string:
  343. // Skip. This is a class name, not a CSS class.
  344. case ConstantCSSClass:
  345. // Skip. This is a class name, not a CSS class.
  346. case CSSClass:
  347. // Skip. This is a class name, not a CSS class.
  348. case map[string]bool:
  349. // Skip. These are class names, not CSS classes.
  350. case KeyValue[string, bool]:
  351. // Skip. These are class names, not CSS classes.
  352. case []KeyValue[string, bool]:
  353. // Skip. These are class names, not CSS classes.
  354. case KeyValue[ConstantCSSClass, bool]:
  355. // Skip. These are class names, not CSS classes.
  356. case []KeyValue[ConstantCSSClass, bool]:
  357. // Skip. These are class names, not CSS classes.
  358. }
  359. }
  360. }
  // SafeCSS is CSS that has been sanitized.
  type SafeCSS string

  // SafeCSSProperty is a CSS property value. SanitizeCSS writes a value of
  // this type as-is and sanitizes only the property name.
  type SafeCSSProperty string

  // safeCSSPropertyType is the reflect.Type of SafeCSSProperty.
  var safeCSSPropertyType = reflect.TypeOf((*SafeCSSProperty)(nil)).Elem()
  365. // SanitizeCSS sanitizes CSS properties to ensure that they are safe.
  366. func SanitizeCSS[T ~string](property string, value T) SafeCSS {
  367. if reflect.TypeOf(value) == safeCSSPropertyType {
  368. return SafeCSS(safehtml.SanitizeCSSProperty(property) + ":" + string(value) + ";")
  369. }
  370. p, v := safehtml.SanitizeCSS(property, string(value))
  371. return SafeCSS(p + ":" + v + ";")
  372. }
  // Attributes is a map[string]any used for spread attributes: each key is an
  // attribute name and each value describes how it is rendered.
  type Attributes map[string]any
  // sortedKeys returns the keys of m in ascending order. The result is a
  // non-nil, empty slice when m has no entries.
  func sortedKeys(m map[string]any) (keys []string) {
  	keys = make([]string, 0, len(m))
  	for key := range m {
  		keys = append(keys, key)
  	}
  	sort.Strings(keys)
  	return keys
  }
  // writeStrings writes each string in ss to w in order. It stops at the first
  // write error and returns it.
  func writeStrings(w io.Writer, ss ...string) (err error) {
  	for i := 0; i < len(ss) && err == nil; i++ {
  		_, err = io.WriteString(w, ss[i])
  	}
  	return err
  }
  394. func RenderAttributes(ctx context.Context, w io.Writer, attributes Attributes) (err error) {
  395. for _, key := range sortedKeys(attributes) {
  396. value := attributes[key]
  397. switch value := value.(type) {
  398. case string:
  399. if err = writeStrings(w, ` `, EscapeString(key), `="`, EscapeString(value), `"`); err != nil {
  400. return err
  401. }
  402. case *string:
  403. if value != nil {
  404. if err = writeStrings(w, ` `, EscapeString(key), `="`, EscapeString(*value), `"`); err != nil {
  405. return err
  406. }
  407. }
  408. case bool:
  409. if value {
  410. if err = writeStrings(w, ` `, EscapeString(key)); err != nil {
  411. return err
  412. }
  413. }
  414. case *bool:
  415. if value != nil && *value {
  416. if err = writeStrings(w, ` `, EscapeString(key)); err != nil {
  417. return err
  418. }
  419. }
  420. case KeyValue[string, bool]:
  421. if value.Value {
  422. if err = writeStrings(w, ` `, EscapeString(key), `="`, EscapeString(value.Key), `"`); err != nil {
  423. return err
  424. }
  425. }
  426. case KeyValue[bool, bool]:
  427. if value.Value && value.Key {
  428. if err = writeStrings(w, ` `, EscapeString(key)); err != nil {
  429. return err
  430. }
  431. }
  432. case func() bool:
  433. if value() {
  434. if err = writeStrings(w, ` `, EscapeString(key)); err != nil {
  435. return err
  436. }
  437. }
  438. }
  439. }
  440. return nil
  441. }
  442. // Context.
  // contextKeyType is the private type of the key under which templ keeps its
  // state in a context.Context.
  type contextKeyType int

  // contextKey is the key used to store and look up that state.
  const contextKey contextKeyType = 0
  445. type contextValue struct {
  446. ss map[string]struct{}
  447. onceHandles map[*OnceHandle]struct{}
  448. children *Component
  449. nonce string
  450. }
  451. func (v *contextValue) setHasBeenRendered(h *OnceHandle) {
  452. if v.onceHandles == nil {
  453. v.onceHandles = map[*OnceHandle]struct{}{}
  454. }
  455. v.onceHandles[h] = struct{}{}
  456. }
  457. func (v *contextValue) getHasBeenRendered(h *OnceHandle) (ok bool) {
  458. if v.onceHandles == nil {
  459. v.onceHandles = map[*OnceHandle]struct{}{}
  460. }
  461. _, ok = v.onceHandles[h]
  462. return
  463. }
  464. func (v *contextValue) addScript(s string) {
  465. if v.ss == nil {
  466. v.ss = map[string]struct{}{}
  467. }
  468. v.ss["script_"+s] = struct{}{}
  469. }
  470. func (v *contextValue) hasScriptBeenRendered(s string) (ok bool) {
  471. if v.ss == nil {
  472. v.ss = map[string]struct{}{}
  473. }
  474. _, ok = v.ss["script_"+s]
  475. return
  476. }
  477. func (v *contextValue) addClass(s string) {
  478. if v.ss == nil {
  479. v.ss = map[string]struct{}{}
  480. }
  481. v.ss["class_"+s] = struct{}{}
  482. }
  483. func (v *contextValue) hasClassBeenRendered(s string) (ok bool) {
  484. if v.ss == nil {
  485. v.ss = map[string]struct{}{}
  486. }
  487. _, ok = v.ss["class_"+s]
  488. return
  489. }
  490. // InitializeContext initializes context used to store internal state used during rendering.
  491. func InitializeContext(ctx context.Context) context.Context {
  492. if _, ok := ctx.Value(contextKey).(*contextValue); ok {
  493. return ctx
  494. }
  495. v := &contextValue{}
  496. ctx = context.WithValue(ctx, contextKey, v)
  497. return ctx
  498. }
  499. func getContext(ctx context.Context) (context.Context, *contextValue) {
  500. v, ok := ctx.Value(contextKey).(*contextValue)
  501. if !ok {
  502. ctx = InitializeContext(ctx)
  503. v = ctx.Value(contextKey).(*contextValue)
  504. }
  505. return ctx, v
  506. }
  // bufferPool holds reusable *bytes.Buffer values.
  var bufferPool = sync.Pool{
  	New: func() any {
  		return &bytes.Buffer{}
  	},
  }

  // GetBuffer takes a buffer from the pool. Pass it to ReleaseBuffer once it is
  // no longer needed.
  func GetBuffer() *bytes.Buffer {
  	buf := bufferPool.Get().(*bytes.Buffer)
  	return buf
  }

  // ReleaseBuffer resets b and puts it back into the pool. b must not be used
  // by the caller afterwards.
  func ReleaseBuffer(b *bytes.Buffer) {
  	b.Reset()
  	bufferPool.Put(b)
  }
  // JoinStringErrs returns s together with the errors in errs combined by
  // errors.Join, which is nil when errs holds no non-nil error.
  func JoinStringErrs(s string, errs ...error) (string, error) {
  	joined := errors.Join(errs...)
  	return s, joined
  }
  // Error is an error raised during template rendering, annotated with the
  // position it relates to.
  type Error struct {
  	// Err is the underlying error.
  	Err error
  	// FileName of the template file.
  	FileName string
  	// Line index of the error.
  	Line int
  	// Col index of the error.
  	Col int
  }

  // Error formats the file name, line, column and underlying error. An empty
  // FileName is shown as "templ".
  func (e Error) Error() string {
  	name := e.FileName
  	if name == "" {
  		name = "templ"
  	}
  	return fmt.Sprintf("%s: error at line %d, col %d: %v", name, e.Line, e.Col, e.Err)
  }

  // Unwrap returns the underlying error so that errors.Is and errors.As can
  // inspect it.
  func (e Error) Unwrap() error {
  	inner := e.Err
  	return inner
  }
  542. // Raw renders the input HTML to the output without applying HTML escaping.
  543. //
  544. // Use of this component presents a security risk - the HTML should come from
  545. // a trusted source, because it will be included as-is in the output.
  546. func Raw[T ~string](html T, errs ...error) Component {
  547. return ComponentFunc(func(ctx context.Context, w io.Writer) (err error) {
  548. if err = errors.Join(errs...); err != nil {
  549. return err
  550. }
  551. _, err = io.WriteString(w, string(html))
  552. return err
  553. })
  554. }
  555. // FromGoHTML creates a templ Component from a Go html/template template.
  556. func FromGoHTML(t *template.Template, data any) Component {
  557. return ComponentFunc(func(ctx context.Context, w io.Writer) (err error) {
  558. return t.Execute(w, data)
  559. })
  560. }
  561. // ToGoHTML renders the component to a Go html/template template.HTML string.
  562. func ToGoHTML(ctx context.Context, c Component) (s template.HTML, err error) {
  563. b := GetBuffer()
  564. defer ReleaseBuffer(b)
  565. if err = c.Render(ctx, b); err != nil {
  566. return
  567. }
  568. s = template.HTML(b.String())
  569. return
  570. }
  571. // WriteWatchModeString is used when rendering templates in development mode.
  572. // the generator would have written non-go code to the _templ.txt file, which
  573. // is then read by this function and written to the output.
  574. func WriteWatchModeString(w io.Writer, lineNum int) error {
  575. _, path, _, _ := runtime.Caller(1)
  576. if !strings.HasSuffix(path, "_templ.go") {
  577. return errors.New("templ: WriteWatchModeString can only be called from _templ.go")
  578. }
  579. txtFilePath := strings.Replace(path, "_templ.go", "_templ.txt", 1)
  580. literals, err := getWatchedStrings(txtFilePath)
  581. if err != nil {
  582. return fmt.Errorf("templ: failed to cache strings: %w", err)
  583. }
  584. if lineNum > len(literals) {
  585. return errors.New("templ: failed to find line " + strconv.Itoa(lineNum) + " in " + txtFilePath)
  586. }
  587. unquoted, err := strconv.Unquote(`"` + literals[lineNum-1] + `"`)
  588. if err != nil {
  589. return err
  590. }
  591. _, err = io.WriteString(w, unquoted)
  592. return err
  593. }
  // watchState is the cached content of one _templ.txt file.
  type watchState struct {
  	// modTime is the modification time of the file when it was read.
  	modTime time.Time
  	// strings holds the lines of the file.
  	strings []string
  }

  var (
  	// watchStateMutex is held by getWatchedStrings while the cache is
  	// read and refreshed.
  	watchStateMutex sync.Mutex
  	// watchModeCache maps a text file path to its cached content.
  	watchModeCache = make(map[string]watchState)
  )
  602. func getWatchedStrings(txtFilePath string) ([]string, error) {
  603. watchStateMutex.Lock()
  604. defer watchStateMutex.Unlock()
  605. state, cached := watchModeCache[txtFilePath]
  606. if !cached {
  607. return cacheStrings(txtFilePath)
  608. }
  609. if time.Since(state.modTime) < time.Millisecond*100 {
  610. return state.strings, nil
  611. }
  612. info, err := os.Stat(txtFilePath)
  613. if err != nil {
  614. return nil, fmt.Errorf("templ: failed to stat %s: %w", txtFilePath, err)
  615. }
  616. if !info.ModTime().After(state.modTime) {
  617. return state.strings, nil
  618. }
  619. return cacheStrings(txtFilePath)
  620. }
  621. func cacheStrings(txtFilePath string) ([]string, error) {
  622. txtFile, err := os.Open(txtFilePath)
  623. if err != nil {
  624. return nil, fmt.Errorf("templ: failed to open %s: %w", txtFilePath, err)
  625. }
  626. defer txtFile.Close()
  627. info, err := txtFile.Stat()
  628. if err != nil {
  629. return nil, fmt.Errorf("templ: failed to stat %s: %w", txtFilePath, err)
  630. }
  631. all, err := io.ReadAll(txtFile)
  632. if err != nil {
  633. return nil, fmt.Errorf("templ: failed to read %s: %w", txtFilePath, err)
  634. }
  635. literals := strings.Split(string(all), "\n")
  636. watchModeCache[txtFilePath] = watchState{
  637. modTime: info.ModTime(),
  638. strings: literals,
  639. }
  640. return literals, nil
  641. }