From b45c3e32539f682547c4082f81c855ba3ff5287c Mon Sep 17 00:00:00 2001 From: Julien Riou Date: Tue, 26 Aug 2025 17:31:44 +0200 Subject: [PATCH] feat: Support Bearer token for clients Fixes #13. Signed-off-by: Julien Riou --- src/cmd/coller/main.go | 64 ++++++++++++++++++++++++----------------- src/cmd/collerd/main.go | 18 ++++-------- src/cmd/copier/main.go | 41 ++++++++++++++++++-------- src/internal/utils.go | 10 +++++++ 4 files changed, 83 insertions(+), 50 deletions(-) diff --git a/src/cmd/coller/main.go b/src/cmd/coller/main.go index d4de856..e951247 100644 --- a/src/cmd/coller/main.go +++ b/src/cmd/coller/main.go @@ -58,7 +58,7 @@ func handleMain() int { configFile := flag.String("config", filepath.Join(homeDir, ".config", AppName+".json"), "Configuration file") reconfigure := flag.Bool("reconfigure", false, "Re-create configuration file") url := flag.String("url", "", "URL of the coller API") - password := flag.String("password", os.Getenv("COPIER_PASSWORD"), "Password to decrypt the note") + password := flag.String("password", os.Getenv("COLLER_PASSWORD"), "Password to encrypt the note") askPassword := flag.Bool("ask-password", false, "Read password from input") noPassword := flag.Bool("no-password", false, "Allow notes without password") passwordLength := flag.Int("password-length", 16, "Length of the auto-generated password") @@ -66,6 +66,8 @@ func handleMain() int { expiration := flag.Int("expiration", 0, "Number of seconds before expiration") deleteAfterRead := flag.Bool("delete-after-read", false, "Delete the note after the first read") copier := flag.Bool("copier", false, "Print the copier command to decrypt the note") + bearer := flag.String("bearer", os.Getenv("COLLER_BEARER"), "Bearer token") + askBearer := flag.Bool("b", false, "Read bearer token from input") flag.Parse() @@ -94,7 +96,7 @@ func handleMain() int { logger.Debug("writing configuration file") err := WriteConfig(config, *configFile) if err != nil { - return ReturnError(logger, "could not create configuration file", err) + return internal.ReturnError(logger, "could not create configuration file", err) } } @@ -102,7 +104,7 @@ func handleMain() int { var config Config err = internal.ReadConfig(*configFile, &config) if err != nil { - return ReturnError(logger, "could not read configuration file", err) + return internal.ReturnError(logger, "could not read configuration file", err) } *url = config.URL @@ -113,12 +115,12 @@ func handleMain() int { logger.Debug("reading from file", slog.Any("file", *fileName)) content, err = os.ReadFile(*fileName) if err != nil { - return ReturnError(logger, "could not read from file", err) + return internal.ReturnError(logger, "could not read from file", err) } } else { err = clipboard.Init() if err != nil { - return ReturnError(logger, "could not initialize clipboard library", err) + return internal.ReturnError(logger, "could not initialize clipboard library", err) } content = clipboard.Read(clipboard.FmtText) } @@ -127,7 +129,7 @@ func handleMain() int { fmt.Print("Password: ") p, err := term.ReadPassword(int(syscall.Stdin)) if err != nil { - return ReturnError(logger, "could not read password", err) + return internal.ReturnError(logger, "could not read password", err) } *password = string(p) fmt.Print("\n") @@ -136,13 +138,13 @@ func handleMain() int { if !*noPassword && *password == "" { logger.Debug("generating random password") if *passwordLength < internal.MIN_PASSWORD_LENGTH || *passwordLength > internal.MAX_PASSWORD_LENGTH { - return ReturnError(logger, "invalid password length for auto-generated password", fmt.Errorf("password length must be between %d and %d", internal.MIN_PASSWORD_LENGTH, internal.MAX_PASSWORD_LENGTH)) + return internal.ReturnError(logger, "invalid password length for auto-generated password", fmt.Errorf("password length must be between %d and %d", internal.MIN_PASSWORD_LENGTH, internal.MAX_PASSWORD_LENGTH)) } *password = internal.GenerateChars(*passwordLength) } if len(content) == 0 { - return ReturnError(logger, "could not create empty note", nil) + return internal.ReturnError(logger, "could not create empty note", nil) } p := NotePayload{} @@ -156,12 +158,12 @@ func handleMain() int { if *password != "" { logger.Debug("validating password") if err = internal.ValidatePassword(*password); err != nil { - return ReturnError(logger, "invalid password", nil) + return internal.ReturnError(logger, "invalid password", nil) } logger.Debug("encrypting content") content, err = internal.Encrypt(content, *password) if err != nil { - return ReturnError(logger, "could not encrypt note", err) + return internal.ReturnError(logger, "could not encrypt note", err) } p.Encrypted = true } @@ -172,32 +174,51 @@ func handleMain() int { payload, err := json.Marshal(p) if err != nil { - return ReturnError(logger, "could not serialize note to json", err) + return internal.ReturnError(logger, "could not serialize note to json", err) } apiRoute := *url + "/api/note" - logger.Debug("creating note", slog.Any("payload", payload), slog.Any("url", apiRoute)) + if *askBearer { + fmt.Print("Bearer: ") + b, err := term.ReadPassword(int(syscall.Stdin)) + if err != nil { + return internal.ReturnError(logger, "could not read bearer token", err) + } + *bearer = string(b) + fmt.Print("\n") + } - r, err := http.Post(apiRoute, "application/json", bytes.NewReader(payload)) + logger.Debug("creating http request") + req, err := http.NewRequest("POST", apiRoute, bytes.NewReader(payload)) if err != nil { - return ReturnError(logger, "could not create note", err) + return internal.ReturnError(logger, "could not create request", err) + } + + if *bearer != "" { + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *bearer)) + } + + logger.Debug("creating note", slog.Any("payload", payload), slog.Any("url", apiRoute)) + r, err := http.DefaultClient.Do(req) + if err != nil { + return internal.ReturnError(logger, "could not create note", err) } logger.Debug("reading response", slog.Any("response", r)) body, err := io.ReadAll(r.Body) if err != nil { - return ReturnError(logger, "could not read response", err) + return internal.ReturnError(logger, "could not read response", err) } jsonBody := &NoteResponse{} err = json.Unmarshal(body, jsonBody) if err != nil { - return ReturnError(logger, "could not decode response", err) + return internal.ReturnError(logger, "could not decode response", err) } if r.StatusCode != http.StatusOK { - return ReturnError(logger, jsonBody.Message, fmt.Errorf("%s", jsonBody.Error)) + return internal.ReturnError(logger, jsonBody.Message, fmt.Errorf("%s", jsonBody.Error)) } logger.Debug("finding note location") @@ -243,12 +264,3 @@ func WriteConfig(config Config, fileName string) error { return nil } - -func ReturnError(logger *slog.Logger, message string, err error) int { - if err != nil { - logger.Error(message, slog.Any("error", err)) - } else { - logger.Error(message) - } - return internal.RC_ERROR -} diff --git a/src/cmd/collerd/main.go b/src/cmd/collerd/main.go index 3145253..35ca6a2 100644 --- a/src/cmd/collerd/main.go +++ b/src/cmd/collerd/main.go @@ -50,27 +50,23 @@ func handleMain() int { if *configFileName != "" { err = internal.ReadConfig(*configFileName, config) if err != nil { - logger.Error("cannot parse configuration file", slog.Any("error", err)) - return internal.RC_ERROR + return internal.ReturnError(logger, "could not parse configuration file", err) } logger.Debug("configuration file parsed", slog.Any("file", *configFileName)) } if err = config.Check(); err != nil { - logger.Error("invalid configuration", slog.Any("error", err)) - return internal.RC_ERROR + return internal.ReturnError(logger, "invalid configuration", err) } db, err := server.NewDatabase(logger, config) if err != nil { - logger.Error("could not connect to the database", slog.Any("error", err)) - return internal.RC_ERROR + return internal.ReturnError(logger, "could not connect to the database", err) } srv, err := server.NewServer(logger, db, config, AppVersion) if err != nil { - logger.Error("could not create server", slog.Any("error", err)) - return internal.RC_ERROR + return internal.ReturnError(logger, "could not create server", err) } srv.SetIDLength(config.IDLength) @@ -80,16 +76,14 @@ func handleMain() int { reg := prometheus.NewRegistry() metrics, err := server.NewMetrics(logger, reg, config, db) if err != nil { - logger.Error("could not register metrics", slog.Any("error", err)) - return internal.RC_ERROR + return internal.ReturnError(logger, "could not register metrics", err) } srv.SetMetrics(metrics) } err = srv.Start() if err != nil { - logger.Error("could not start server", slog.Any("error", err)) - return internal.RC_ERROR + return internal.ReturnError(logger, "could not start server", err) } return internal.RC_OK diff --git a/src/cmd/copier/main.go b/src/cmd/copier/main.go index 7511394..611d146 100644 --- a/src/cmd/copier/main.go +++ b/src/cmd/copier/main.go @@ -28,9 +28,11 @@ func handleMain() int { quiet := flag.Bool("quiet", false, "Log errors only") verbose := flag.Bool("verbose", false, "Print more logs") debug := flag.Bool("debug", false, "Print even more logs") - password := flag.String("password", os.Getenv("COPIER_PASSWORD"), "Password to decrypt the note") + password := flag.String("password", os.Getenv("COLLER_PASSWORD"), "Password to decrypt the note") askPassword := flag.Bool("w", false, "Read password from input") fileName := flag.String("file", "", "Write content of the note to a file") + bearer := flag.String("bearer", os.Getenv("COLLER_BEARER"), "Bearer token") + askBearer := flag.Bool("b", false, "Read bearer token from input") flag.Parse() @@ -62,24 +64,41 @@ func handleMain() int { fmt.Print("Password: ") p, err := term.ReadPassword(int(syscall.Stdin)) if err != nil { - logger.Error("could not read password", slog.Any("error", err)) - return internal.RC_ERROR + return internal.ReturnError(logger, "could not read password", err) } *password = string(p) + fmt.Print("\n") + } + + if *askBearer { + fmt.Print("Bearer: ") + b, err := term.ReadPassword(int(syscall.Stdin)) + if err != nil { + return internal.ReturnError(logger, "could not read bearer token", err) + } + *bearer = string(b) + fmt.Print("\n") + } + + logger.Debug("creating http request") + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return internal.ReturnError(logger, "could not create request", err) + } + if *bearer != "" { + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *bearer)) } logger.Debug("parsing url", slog.Any("url", url)) - r, err := http.Get(url) + r, err := http.DefaultClient.Do(req) if err != nil { - logger.Error("could not retreive note", slog.Any("error", err)) - return internal.RC_ERROR + return internal.ReturnError(logger, "could not retreive note", err) } logger.Debug("decoding body") body, err := io.ReadAll(r.Body) if err != nil { - logger.Error("could not read response", slog.Any("error", err)) - return internal.RC_ERROR + return internal.ReturnError(logger, "could not read response", err) } var content []byte @@ -87,8 +106,7 @@ func handleMain() int { logger.Debug("decrypting note") content, err = internal.Decrypt(body, *password) if err != nil { - logger.Error("could not decrypt paste", slog.Any("error", err)) - return internal.RC_ERROR + return internal.ReturnError(logger, "could not decrypt paste", err) } } else { content = body @@ -98,8 +116,7 @@ func handleMain() int { logger.Debug("writing output to file", slog.Any("file", *fileName)) err = os.WriteFile(*fileName, content, 0644) if err != nil { - logger.Error("could not write output to file", slog.Any("error", err)) - return internal.RC_ERROR + return internal.ReturnError(logger, "could not write output to file", err) } } else { fmt.Printf("%s", content) diff --git a/src/internal/utils.go b/src/internal/utils.go index 42f91cd..7844792 100644 --- a/src/internal/utils.go +++ b/src/internal/utils.go @@ -3,6 +3,7 @@ package internal import ( "encoding/json" "fmt" + "log/slog" "math/rand" "os" "path/filepath" @@ -99,3 +100,12 @@ func HumanDuration(i int) string { } return fmt.Sprintf("%d %s", i, w) } + +func ReturnError(logger *slog.Logger, message string, err error) int { + if err != nil { + logger.Error(message, slog.Any("error", err)) + } else { + logger.Error(message) + } + return RC_ERROR +}