coller/src/server/db.go
Julien Riou 7c00b364d1
docs: Add READMEs
Signed-off-by: Julien Riou <julien@riou.xyz>
2025-08-24 15:48:12 +02:00

151 lines
3.7 KiB
Go

package server
import (
"fmt"
"log/slog"
"slices"
"strings"
"time"
"git.riou.xyz/jriou/coller/internal"
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
// Database wraps the gorm handle used to store notes, together with the
// expiration settings that NewDatabase copies from Config.
type Database struct {
	logger             *slog.Logger // logger scoped with module=db by NewDatabase
	db                 *gorm.DB     // underlying gorm handle
	expirationInterval int          // pause, in seconds, between two purges of expired notes
	expirations        []int        // allowed note lifetimes, in seconds; Create rejects any other value
	expiration         int          // default lifetime in seconds, used when Create receives 0
}
// gconfig is the gorm configuration passed to every gorm.Open call in this
// file. gorm's built-in logger is silenced; this package reports through its
// own slog logger instead.
var gconfig = &gorm.Config{Logger: logger.Default.LogMode(logger.Silent)}
// NewDatabase opens the database described by config, migrates its schema and
// starts the background goroutine that purges expired notes.
//
// config.DatabaseType selects the gorm driver ("postgres" or "sqlite"); any
// other value yields an error. config.DatabaseDsn is handed unchanged to the
// selected driver. If the schema migration fails, the freshly opened
// connection pool is closed before the error is returned, so it is not leaked.
//
// NOTE(review): the expiration goroutine has no stop mechanism and lives for
// the remainder of the process.
func NewDatabase(logger *slog.Logger, config *Config) (d *Database, err error) {
	// Every message from this layer carries module=db, including the
	// connection messages below.
	l := logger.With("module", "db")

	l.Debug("connecting to the database")
	var db *gorm.DB
	switch config.DatabaseType {
	case "postgres":
		db, err = gorm.Open(postgres.New(postgres.Config{DSN: config.DatabaseDsn}), gconfig)
	case "sqlite":
		db, err = gorm.Open(sqlite.Open(config.DatabaseDsn), gconfig)
	default:
		return nil, fmt.Errorf("database type '%s' not supported", config.DatabaseType)
	}
	if err != nil {
		return nil, fmt.Errorf("connecting to %s database: %w", config.DatabaseType, err)
	}
	l.Debug("connected to the database")

	d = &Database{
		logger:             l,
		db:                 db,
		expirationInterval: config.ExpirationInterval,
		expirations:        config.Expirations,
		expiration:         config.Expiration,
	}
	if err = d.UpdateSchema(); err != nil {
		// The handle is unusable without a valid schema; release its pool.
		// A close failure is only logged because the migration error is the
		// one the caller needs to see.
		if sqlDB, dbErr := db.DB(); dbErr == nil {
			if closeErr := sqlDB.Close(); closeErr != nil {
				l.Warn("could not close database", slog.Any("error", closeErr))
			}
		}
		return nil, fmt.Errorf("updating database schema: %w", err)
	}

	go d.StartExpireThread()
	return d, nil
}
// UpdateSchema runs gorm's AutoMigrate for the Note model and returns the
// migration error, if any. The "updated" debug message is only emitted on
// success.
func (d *Database) UpdateSchema() error {
	d.logger.Debug("updating database schema")
	err := d.db.AutoMigrate(&Note{})
	if err == nil {
		d.logger.Debug("database schema updated")
	}
	return err
}
// StartExpireThread loops forever. On each pass it deletes every note whose
// expires_at is at or before the current time. NewDatabase runs it in its own
// goroutine, and it has no stop mechanism.
//
// Passes are separated by expirationInterval seconds. A non-positive interval
// is treated as one second, so that a misconfiguration cannot turn this loop
// into a busy loop that hammers the database.
func (d *Database) StartExpireThread() {
	// Loop-invariant: computed once instead of on every pass.
	seconds := d.expirationInterval
	if seconds < 1 {
		seconds = 1
	}
	wait := time.Duration(seconds) * time.Second

	for {
		d.logger.Debug("deleting expired notes")
		trx := d.db.Where("expires_at <= ?", time.Now()).Delete(&Note{})
		if trx.Error != nil {
			d.logger.Error("could not delete expired notes", slog.Any("error", trx.Error))
		} else {
			d.logger.Debug("expired notes deleted", slog.Int64("count", trx.RowsAffected))
		}
		d.logger.Debug("waiting before next expiration", slog.Duration("interval", wait))
		time.Sleep(wait)
	}
}
// Get returns the note identified by id. It returns (nil, nil) when no such
// note exists or when the note's expires_at is already in the past; the purge
// goroutine may not have removed an expired note yet.
//
// A note flagged DeleteAfterRead is deleted as part of the read. It is only
// returned when this call's DELETE actually removed the row, so two concurrent
// readers cannot both obtain the same read-once note.
func (d *Database) Get(id string) (*Note, error) {
	var note Note
	trx := d.db.Where("id = ? AND expires_at > ?", id, time.Now()).Find(&note)
	if trx.Error != nil {
		d.logger.Warn("could not get note", slog.Any("error", trx.Error))
		return nil, trx.Error
	}
	if note.ID == "" {
		return nil, nil
	}
	if !note.DeleteAfterRead {
		return &note, nil
	}

	// The delete is done inline rather than through d.Delete because we need
	// RowsAffected. It tells us whether this call, and not a concurrent
	// reader, consumed the note.
	trx = d.db.Where("id = ?", note.ID).Delete(&Note{})
	if trx.Error != nil {
		d.logger.Error("could not delete note", slog.Any("error", trx.Error))
		return nil, trx.Error
	}
	if trx.RowsAffected == 0 {
		// Another reader deleted it between our SELECT and DELETE.
		return nil, nil
	}
	return &note, nil
}
// Create stores a new note and returns it.
//
// expiration is a lifetime in seconds. 0 selects the configured default, and
// any value not in the configured list of allowed lifetimes is rejected. When
// password is non-empty, it is validated and used to encrypt the content, and
// the note is flagged as encrypted regardless of the encrypted argument.
func (d *Database) Create(content []byte, password string, encrypted bool, expiration int, deleteAfterRead bool) (note *Note, err error) {
	lifetime := expiration
	if lifetime == 0 {
		lifetime = d.expiration
	}
	if !slices.Contains(d.expirations, lifetime) {
		allowed := make([]string, len(d.expirations))
		for i, e := range d.expirations {
			allowed[i] = fmt.Sprint(e)
		}
		return nil, fmt.Errorf("invalid expiration: must be one of %s", strings.Join(allowed, ", "))
	}

	expiresAt := time.Now().Add(time.Duration(lifetime) * time.Second)

	if password != "" {
		if err = internal.ValidatePassword(password); err != nil {
			return nil, err
		}
		if content, err = internal.Encrypt(content, password); err != nil {
			return nil, err
		}
		encrypted = true
	}

	note = &Note{
		Content:         content,
		ExpiresAt:       expiresAt,
		Encrypted:       encrypted,
		DeleteAfterRead: deleteAfterRead,
	}
	if res := d.db.Create(note); res.Error != nil {
		d.logger.Warn("could not create note", slog.Any("error", res.Error))
		return nil, res.Error
	}
	return note, nil
}
// Delete removes the note identified by id. Deleting a missing note is not an
// error; any database error is logged and returned.
func (d *Database) Delete(id string) error {
	err := d.db.Where("id = ?", id).Delete(&Note{}).Error
	if err != nil {
		d.logger.Error("could not delete note", slog.Any("error", err))
	}
	return err
}