forked from jriou/coller
178 lines
4.9 KiB
Go
178 lines
4.9 KiB
Go
package server
|
|
|
|
import (
	"embed"
	"encoding/json"
	"fmt"
	"html/template"
	"log/slog"
	"net/http"
	"strings"
	"time"

	"github.com/gorilla/mux"
	"github.com/prometheus/client_golang/prometheus/promhttp"

	"git.riou.xyz/jriou/coller/internal"
)
|
|
|
|
var (
|
|
encryptionKeyLength = internal.MIN_ENCRYPTION_KEY_LENGTH
|
|
supportedOSes = []string{"linux", "darwin"}
|
|
supportedArches = []string{"amd64", "arm64"}
|
|
supportedClients = []string{"coller", "copier"}
|
|
)
|
|
|
|
type Server struct {
|
|
logger *slog.Logger
|
|
db *Database
|
|
config *Config
|
|
version string
|
|
metrics *Metrics
|
|
}
|
|
|
|
func NewServer(logger *slog.Logger, db *Database, config *Config, version string) (*Server, error) {
|
|
l := logger.With("module", "server")
|
|
|
|
return &Server{
|
|
logger: l,
|
|
db: db,
|
|
config: config,
|
|
version: version,
|
|
}, nil
|
|
}
|
|
|
|
func (s *Server) SetEncryptionKeyLength(length int) {
|
|
encryptionKeyLength = length
|
|
}
|
|
|
|
func (s *Server) SetMetrics(metrics *Metrics) {
|
|
s.metrics = metrics
|
|
}
|
|
|
|
// ErrorResponse is the JSON body written by WriteError.
type ErrorResponse struct {
	// Message is the caller-supplied description of what failed.
	Message string `json:"message"`
	// Error is the textual form of the underlying error.
	Error string `json:"error"`
}

// ToJSON serializes e to a JSON string.
//
// If marshalling fails (not expected for these string-only fields), a
// fallback JSON document describing the marshalling error is returned instead,
// so the result is always valid JSON.
func (e ErrorResponse) ToJSON() string {
	b, err := json.Marshal(e)
	if err != nil {
		// Encode the error text as a JSON string instead of splicing it in raw:
		// quotes, backslashes or control characters would otherwise yield an
		// invalid document. Marshalling a plain string does not return an error.
		detail, _ := json.Marshal(err.Error())
		return fmt.Sprintf("{\"message\":\"could not serialize response to JSON\", \"error\":%s}", detail)
	}
	return string(b)
}

// WriteError responds with HTTP 500 and a JSON ErrorResponse built from
// message and err.
//
// The status is always http.StatusInternalServerError. A nil err is rendered
// as "<nil>" in the error field.
func WriteError(w http.ResponseWriter, message string, err error) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusInternalServerError)
	body := ErrorResponse{
		Message: message,
		Error:   fmt.Sprintf("%v", err),
	}.ToJSON()
	// The status line has already been sent, so a failed body write cannot be
	// reported to the client; the result is deliberately ignored.
	_, _ = fmt.Fprint(w, body)
}
|
|
|
|
type GetProtectedWebNoteHandler struct {
|
|
Templates *template.Template
|
|
PageData PageData
|
|
logger *slog.Logger
|
|
db *Database
|
|
}
|
|
|
|
type ClientsHandler struct {
|
|
Templates *template.Template
|
|
PageData PageData
|
|
logger *slog.Logger
|
|
}
|
|
|
|
// templatesFS embeds the templates directory (pattern templates/*) into the
// binary; Start parses its *.html files into the page templates.
//
//go:embed templates/*
var templatesFS embed.FS
|
|
|
|
func (s *Server) Start() error {
|
|
r := mux.NewRouter().StrictSlash(true)
|
|
|
|
// Healthchecks
|
|
r.HandleFunc("/health", HealthHandler)
|
|
|
|
// Metrics
|
|
if s.metrics != nil && s.metrics.reg != nil {
|
|
r.Path(s.config.PrometheusRoute).Handler(promhttp.HandlerFor(s.metrics.reg, promhttp.HandlerOpts{Registry: s.metrics.reg})).Methods("GET")
|
|
}
|
|
|
|
// API
|
|
r.Path("/api/note").Handler(&CreateNoteHandler{logger: s.logger, db: s.db, maxUploadSize: s.config.MaxUploadSize}).Methods("POST")
|
|
r.Path("/{id:[a-zA-Z0-9]+}/{encryptionKey:[a-zA-Z0-9]+}").Handler(&GetProtectedNoteHandler{logger: s.logger, db: s.db}).Methods("GET")
|
|
r.Path("/{id:[a-zA-Z0-9]+}").Handler(&GetNoteHandler{logger: s.logger, db: s.db}).Methods("GET")
|
|
|
|
// Web pages
|
|
funcs := template.FuncMap{
|
|
"HumanDuration": internal.HumanDuration,
|
|
"TimeDiff": internal.TimeDiff,
|
|
"lower": strings.ToLower,
|
|
"string": func(b []byte) string { return string(b) },
|
|
}
|
|
p := PageData{
|
|
Title: s.config.Title,
|
|
Expirations: s.config.Expirations,
|
|
Expiration: s.config.Expiration,
|
|
Languages: s.config.Languages,
|
|
BootstrapDirectory: s.config.BootstrapDirectory,
|
|
}
|
|
|
|
if s.config.ShowVersion {
|
|
p.Version = s.version
|
|
}
|
|
p.EnableUploadFileButton = s.config.EnableUploadFileButton
|
|
|
|
templates, err := template.New("templates").Funcs(funcs).ParseFS(templatesFS, "templates/*.html")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
createNoteWithFormHandler := &CreateNoteWithFormHandler{
|
|
Templates: templates,
|
|
PageData: p,
|
|
logger: s.logger,
|
|
db: s.db,
|
|
maxUploadSize: s.config.MaxUploadSize,
|
|
}
|
|
r.Path("/create").Handler(createNoteWithFormHandler).Methods("POST")
|
|
|
|
clientsHandler := &ClientsHandler{
|
|
Templates: templates,
|
|
PageData: p,
|
|
logger: s.logger,
|
|
}
|
|
r.Path("/clients.html").Handler(clientsHandler).Methods("GET")
|
|
r.Path("/clients/{os:[a-z]+}-{arch:[a-z0-9]+}/{clientName:[a-z]+}").Handler(&ClientHandler{logger: s.logger, version: p.Version}).Methods("GET")
|
|
|
|
protectedWebNoteHandler := &GetProtectedWebNoteHandler{
|
|
Templates: templates,
|
|
PageData: p,
|
|
logger: s.logger,
|
|
db: s.db,
|
|
}
|
|
r.Path("/{id:[a-zA-Z0-9]+}/{encryptionKey:[a-zA-Z0-9]+}.html").Handler(protectedWebNoteHandler).Methods("GET")
|
|
|
|
webNoteHandler := &GetWebNoteHandler{
|
|
Templates: templates,
|
|
PageData: p,
|
|
logger: s.logger,
|
|
db: s.db,
|
|
}
|
|
r.Path("/{id:[a-zA-Z0-9]+}.html").Handler(webNoteHandler).Methods("GET")
|
|
|
|
if s.config.BootstrapDirectory != "" {
|
|
r.PathPrefix("/static/bootstrap/").Handler(http.StripPrefix("/static/bootstrap/", http.FileServer(http.Dir(s.config.BootstrapDirectory))))
|
|
}
|
|
|
|
r.Path("/").Handler(&HomeHandler{Templates: templates, PageData: p}).Methods("GET")
|
|
|
|
addr := fmt.Sprintf("%s:%d", s.config.ListenAddress, s.config.ListenPort)
|
|
|
|
if s.config.HasTLS() {
|
|
s.logger.Info(fmt.Sprintf("listening to %s:%d (https)", s.config.ListenAddress, s.config.ListenPort))
|
|
return http.ListenAndServeTLS(addr, s.config.TLSCertFile, s.config.TLSKeyFile, r)
|
|
} else {
|
|
s.logger.Info(fmt.Sprintf("listening to %s:%d (http)", s.config.ListenAddress, s.config.ListenPort))
|
|
return http.ListenAndServe(addr, r)
|
|
}
|
|
}
|