151 lines
3.7 KiB
Go
151 lines
3.7 KiB
Go
package server
|
|
|
|
import (
|
|
"fmt"
|
|
"log/slog"
|
|
"slices"
|
|
"strings"
|
|
"time"
|
|
|
|
"git.riou.xyz/jriou/coller/internal"
|
|
"gorm.io/driver/postgres"
|
|
"gorm.io/driver/sqlite"
|
|
"gorm.io/gorm"
|
|
"gorm.io/gorm/logger"
|
|
)
|
|
|
|
type Database struct {
|
|
logger *slog.Logger
|
|
db *gorm.DB
|
|
expirationInterval int
|
|
expirations []int
|
|
expiration int
|
|
}
|
|
|
|
var gconfig = &gorm.Config{
|
|
Logger: logger.Default.LogMode(logger.Silent),
|
|
}
|
|
|
|
func NewDatabase(logger *slog.Logger, config *Config) (d *Database, err error) {
|
|
l := logger.With("module", "db")
|
|
|
|
logger.Debug("connecting to the database")
|
|
var db *gorm.DB
|
|
|
|
switch config.DatabaseType {
|
|
case "postgres":
|
|
db, err = gorm.Open(postgres.New(postgres.Config{DSN: config.DatabaseDsn}), gconfig)
|
|
case "sqlite":
|
|
db, err = gorm.Open(sqlite.Open(config.DatabaseDsn), gconfig)
|
|
default:
|
|
return nil, fmt.Errorf("database type '%s' not supported", config.DatabaseType)
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
logger.Debug("connected to the database")
|
|
|
|
d = &Database{
|
|
logger: l,
|
|
db: db,
|
|
expirationInterval: config.ExpirationInterval,
|
|
expirations: config.Expirations,
|
|
expiration: config.Expiration,
|
|
}
|
|
|
|
if err = d.UpdateSchema(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
go d.StartExpireThread()
|
|
|
|
return d, nil
|
|
}
|
|
|
|
func (d *Database) UpdateSchema() error {
|
|
d.logger.Debug("updating database schema")
|
|
if err := d.db.AutoMigrate(&Note{}); err != nil {
|
|
return err
|
|
}
|
|
d.logger.Debug("database schema updated")
|
|
return nil
|
|
}
|
|
|
|
func (d *Database) StartExpireThread() {
|
|
for {
|
|
d.logger.Debug("deleting expired notes")
|
|
trx := d.db.Where("expires_at <= ?", time.Now()).Delete(&Note{})
|
|
if trx.Error != nil {
|
|
d.logger.Error("could not delete note", slog.Any("error", trx.Error))
|
|
}
|
|
d.logger.Debug("expired notes deleted")
|
|
|
|
wording := "second"
|
|
if d.expirationInterval > 1 {
|
|
wording += "s"
|
|
}
|
|
d.logger.Debug(fmt.Sprintf("waiting for %d %s before next expiration", d.expirationInterval, wording))
|
|
time.Sleep(time.Duration(d.expirationInterval) * time.Second)
|
|
}
|
|
}
|
|
|
|
func (d *Database) Get(id string) (*Note, error) {
|
|
var note Note
|
|
trx := d.db.Where("id = ?", id).Find(¬e)
|
|
if trx.Error != nil {
|
|
d.logger.Warn("could not find note", slog.Any("error", trx.Error))
|
|
return nil, trx.Error
|
|
}
|
|
if note.ID != "" {
|
|
if note.DeleteAfterRead {
|
|
if err := d.Delete(note.ID); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
return ¬e, nil
|
|
}
|
|
return nil, nil
|
|
}
|
|
|
|
func (d *Database) Create(content []byte, password string, encrypted bool, expiration int, deleteAfterRead bool) (note *Note, err error) {
|
|
if expiration == 0 {
|
|
expiration = d.expiration
|
|
}
|
|
if !slices.Contains(d.expirations, expiration) {
|
|
validExpirations := strings.Trim(strings.Join(strings.Fields(fmt.Sprint(d.expirations)), ", "), "[]")
|
|
return nil, fmt.Errorf("invalid expiration: must be one of %s", validExpirations)
|
|
}
|
|
|
|
note = &Note{
|
|
Content: content,
|
|
ExpiresAt: time.Now().Add(time.Duration(expiration) * time.Second),
|
|
Encrypted: encrypted,
|
|
DeleteAfterRead: deleteAfterRead,
|
|
}
|
|
if password != "" {
|
|
if err = internal.ValidatePassword(password); err != nil {
|
|
return nil, err
|
|
}
|
|
note.Content, err = internal.Encrypt(note.Content, password)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
note.Encrypted = true
|
|
}
|
|
trx := d.db.Create(note)
|
|
if trx.Error != nil {
|
|
d.logger.Warn("could not create note", slog.Any("error", trx.Error))
|
|
return nil, trx.Error
|
|
}
|
|
return note, nil
|
|
}
|
|
|
|
func (d *Database) Delete(id string) error {
|
|
trx := d.db.Where("id = ?", id).Delete(&Note{})
|
|
if trx.Error != nil {
|
|
d.logger.Error("could not delete note", slog.Any("error", trx.Error))
|
|
return trx.Error
|
|
}
|
|
return nil
|
|
}
|