package server import ( "fmt" "log/slog" "slices" "strings" "time" "git.riou.xyz/jriou/coller/internal" "gorm.io/driver/postgres" "gorm.io/driver/sqlite" "gorm.io/gorm" "gorm.io/gorm/logger" ) type Database struct { logger *slog.Logger db *gorm.DB expirationInterval int expirations []int expiration int } var gconfig = &gorm.Config{ Logger: logger.Default.LogMode(logger.Silent), } func NewDatabase(logger *slog.Logger, config *Config) (d *Database, err error) { l := logger.With("module", "db") logger.Debug("connecting to the database") var db *gorm.DB switch config.DatabaseType { case "postgres": db, err = gorm.Open(postgres.New(postgres.Config{DSN: config.DatabaseDsn}), gconfig) case "sqlite": db, err = gorm.Open(sqlite.Open(config.DatabaseDsn), gconfig) default: return nil, fmt.Errorf("database type '%s' not supported", config.DatabaseType) } if err != nil { return nil, err } logger.Debug("connected to the database") d = &Database{ logger: l, db: db, expirationInterval: config.ExpirationInterval, expirations: config.Expirations, expiration: config.Expiration, } if err = d.UpdateSchema(); err != nil { return nil, err } go d.StartExpireThread() return d, nil } func (d *Database) UpdateSchema() error { d.logger.Debug("updating database schema") if err := d.db.AutoMigrate(&Note{}); err != nil { return err } d.logger.Debug("database schema updated") return nil } func (d *Database) StartExpireThread() { for { d.logger.Debug("deleting expired notes") trx := d.db.Where("expires_at <= ?", time.Now()).Delete(&Note{}) if trx.Error != nil { d.logger.Error("could not delete note", slog.Any("error", trx.Error)) } d.logger.Debug("expired notes deleted") wording := "second" if d.expirationInterval > 1 { wording += "s" } d.logger.Debug(fmt.Sprintf("waiting for %d %s before next expiration", d.expirationInterval, wording)) time.Sleep(time.Duration(d.expirationInterval) * time.Second) } } func (d *Database) Get(id string) (*Note, error) { var note Note trx := d.db.Where("id = ?", id).Find(¬e) if trx.Error != nil { d.logger.Warn("could not find note", slog.Any("error", trx.Error)) return nil, trx.Error } if note.ID != "" { if note.DeleteAfterRead { if err := d.Delete(note.ID); err != nil { return nil, err } } return ¬e, nil } return nil, nil } func (d *Database) Create(content []byte, password string, encrypted bool, expiration int, deleteAfterRead bool) (note *Note, err error) { if expiration == 0 { expiration = d.expiration } if !slices.Contains(d.expirations, expiration) { validExpirations := strings.Trim(strings.Join(strings.Fields(fmt.Sprint(d.expirations)), ", "), "[]") return nil, fmt.Errorf("invalid expiration: must be one of %s", validExpirations) } note = &Note{ Content: content, ExpiresAt: time.Now().Add(time.Duration(expiration) * time.Second), Encrypted: encrypted, DeleteAfterRead: deleteAfterRead, } if password != "" { if err = internal.ValidatePassword(password); err != nil { return nil, err } note.Content, err = internal.Encrypt(note.Content, password) if err != nil { return nil, err } note.Encrypted = true } trx := d.db.Create(note) if trx.Error != nil { d.logger.Warn("could not create note", slog.Any("error", trx.Error)) return nil, trx.Error } return note, nil } func (d *Database) Delete(id string) error { trx := d.db.Where("id = ?", id).Delete(&Note{}) if trx.Error != nil { d.logger.Error("could not delete note", slog.Any("error", trx.Error)) return trx.Error } return nil }