forked from jriou/coller
Initial commit
Signed-off-by: Julien Riou <julien@riou.xyz>
This commit is contained in:
commit
ef9aca1f3b
26 changed files with 1668 additions and 0 deletions
149
src/server/db.go
Normal file
149
src/server/db.go
Normal file
|
@ -0,0 +1,149 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.riou.xyz/jriou/coller/internal"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
type Database struct {
|
||||
logger *slog.Logger
|
||||
db *gorm.DB
|
||||
expirationInterval int
|
||||
expirations []int
|
||||
expiration int
|
||||
}
|
||||
|
||||
var gconfig = &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
}
|
||||
|
||||
func NewDatabase(logger *slog.Logger, config *Config) (d *Database, err error) {
|
||||
l := logger.With("module", "db")
|
||||
|
||||
logger.Debug("connecting to the database")
|
||||
var db *gorm.DB
|
||||
|
||||
switch config.DatabaseType {
|
||||
case "postgres":
|
||||
db, err = gorm.Open(postgres.New(postgres.Config{DSN: config.DatabaseDsn}), gconfig)
|
||||
default:
|
||||
db, err = gorm.Open(sqlite.Open(config.DatabaseDsn), gconfig)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
logger.Debug("connected to the database")
|
||||
|
||||
d = &Database{
|
||||
logger: l,
|
||||
db: db,
|
||||
expirationInterval: config.ExpirationInterval,
|
||||
expirations: config.Expirations,
|
||||
expiration: config.Expiration,
|
||||
}
|
||||
|
||||
if err = d.UpdateSchema(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
go d.StartExpireThread()
|
||||
|
||||
return d, nil
|
||||
}
|
||||
|
||||
func (d *Database) UpdateSchema() error {
|
||||
d.logger.Debug("updating database schema")
|
||||
if err := d.db.AutoMigrate(&Note{}); err != nil {
|
||||
return err
|
||||
}
|
||||
d.logger.Debug("database schema updated")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Database) StartExpireThread() {
|
||||
for {
|
||||
d.logger.Debug("deleting expired notes")
|
||||
trx := d.db.Where("expires_at <= ?", time.Now()).Delete(&Note{})
|
||||
if trx.Error != nil {
|
||||
d.logger.Error("could not delete note", slog.Any("error", trx.Error))
|
||||
}
|
||||
d.logger.Debug("expired notes deleted")
|
||||
|
||||
wording := "second"
|
||||
if d.expirationInterval > 1 {
|
||||
wording += "s"
|
||||
}
|
||||
d.logger.Debug(fmt.Sprintf("waiting for %d %s before next expiration", d.expirationInterval, wording))
|
||||
time.Sleep(time.Duration(d.expirationInterval) * time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Database) Get(id string) (*Note, error) {
|
||||
var note Note
|
||||
trx := d.db.Where("id = ?", id).Find(¬e)
|
||||
if trx.Error != nil {
|
||||
d.logger.Warn("could not find note", slog.Any("error", trx.Error))
|
||||
return nil, trx.Error
|
||||
}
|
||||
if note.ID != "" {
|
||||
if note.DeleteAfterRead {
|
||||
if err := d.Delete(note.ID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return ¬e, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (d *Database) Create(content []byte, password string, encrypted bool, expiration int, deleteAfterRead bool) (note *Note, err error) {
|
||||
if expiration == 0 {
|
||||
expiration = d.expiration
|
||||
}
|
||||
if !slices.Contains(d.expirations, expiration) {
|
||||
validExpirations := strings.Trim(strings.Join(strings.Fields(fmt.Sprint(d.expirations)), ", "), "[]")
|
||||
return nil, fmt.Errorf("invalid expiration: must be one of %s", validExpirations)
|
||||
}
|
||||
|
||||
note = &Note{
|
||||
Content: content,
|
||||
ExpiresAt: time.Now().Add(time.Duration(expiration) * time.Second),
|
||||
Encrypted: encrypted,
|
||||
DeleteAfterRead: deleteAfterRead,
|
||||
}
|
||||
if password != "" {
|
||||
if err = internal.ValidatePassword(password); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
note.Content, err = internal.Encrypt(note.Content, password)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
note.Encrypted = true
|
||||
}
|
||||
trx := d.db.Create(note)
|
||||
if trx.Error != nil {
|
||||
d.logger.Warn("could not create note", slog.Any("error", trx.Error))
|
||||
return nil, trx.Error
|
||||
}
|
||||
return note, nil
|
||||
}
|
||||
|
||||
func (d *Database) Delete(id string) error {
|
||||
trx := d.db.Where("id = ?", id).Delete(&Note{})
|
||||
if trx.Error != nil {
|
||||
d.logger.Error("could not delete note", slog.Any("error", trx.Error))
|
||||
return trx.Error
|
||||
}
|
||||
return nil
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue