1
0
Fork 0
forked from jriou/coller
coller/src/server/db.go
Julien Riou 8e1dd686d3
feat: Rename password by encryption key
Signed-off-by: Julien Riou <julien@riou.xyz>
2025-09-24 07:09:01 +02:00

178 lines
4.5 KiB
Go

package server
import (
"fmt"
"log/slog"
"slices"
"strings"
"time"
"github.com/bwmarrin/snowflake"
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"git.riou.xyz/jriou/coller/internal"
)
// Database wraps the GORM connection used to store notes, together with the
// note-creation policy (allowed/default expirations and languages) that
// NewDatabase copies from the configuration.
type Database struct {
// logger is the caller's logger scoped with module=db.
logger *slog.Logger
// db is the GORM handle opened for the configured driver.
db *gorm.DB
// expirationInterval is the number of seconds StartExpireThread sleeps
// between two purges of expired notes.
expirationInterval int
// expirations lists the note lifetimes, in seconds, that Create accepts.
expirations []int
// expiration is the default lifetime, in seconds, used when Create is
// called with expiration == 0.
expiration int
// languages lists the accepted note languages, lowercased by NewDatabase.
languages []string
// language is the default language (lowercased) used when Create is called
// with an empty language.
language string
// node generates the snowflake IDs assigned to new notes.
node *snowflake.Node
}
// gconfig is the GORM configuration handed to gorm.Open for every supported
// driver; GORM's built-in logger is put in silent mode.
var gconfig = &gorm.Config{Logger: logger.Default.LogMode(logger.Silent)}
// NewDatabase opens the database described by config, applies the schema
// migration, and starts the background goroutine that purges expired notes.
//
// Supported config.DatabaseType values are "postgres" and "sqlite";
// config.DatabaseDsn is passed to the matching GORM driver unchanged. The
// configured languages and default language are lowercased before being
// stored on the returned Database.
//
// If any step after the connection has been opened fails, the underlying
// connection pool is closed before the error is returned so it is not leaked.
func NewDatabase(logger *slog.Logger, config *Config) (d *Database, err error) {
	l := logger.With("module", "db")

	l.Debug("connecting to the database")
	var db *gorm.DB
	switch config.DatabaseType {
	case "postgres":
		db, err = gorm.Open(postgres.New(postgres.Config{DSN: config.DatabaseDsn}), gconfig)
	case "sqlite":
		db, err = gorm.Open(sqlite.Open(config.DatabaseDsn), gconfig)
	default:
		return nil, fmt.Errorf("database type '%s' not supported", config.DatabaseType)
	}
	if err != nil {
		return nil, fmt.Errorf("connecting to the %s database: %w", config.DatabaseType, err)
	}
	l.Debug("connected to the database")

	// The connection is open from here on: release it if any remaining
	// initialisation step fails. err is the named result, so this sees the
	// value set by whichever return statement ends the function.
	defer func() {
		if err == nil {
			return
		}
		sqlDB, dbErr := db.DB()
		if dbErr != nil {
			return
		}
		if closeErr := sqlDB.Close(); closeErr != nil {
			l.Warn("could not close database after failed initialization", slog.Any("error", closeErr))
		}
	}()

	node, err := snowflake.NewNode(config.NodeID)
	if err != nil {
		return nil, fmt.Errorf("creating snowflake node %d: %w", config.NodeID, err)
	}

	d = &Database{
		logger:             l,
		db:                 db,
		expirationInterval: config.ExpirationInterval,
		expirations:        config.Expirations,
		expiration:         config.Expiration,
		languages:          internal.ToLowerStringSlice(config.Languages),
		language:           strings.ToLower(config.Language),
		node:               node,
	}

	if err = d.UpdateSchema(); err != nil {
		return nil, fmt.Errorf("updating database schema: %w", err)
	}

	// The purge loop runs for the remaining lifetime of the process; it has no
	// stop hook.
	go d.StartExpireThread()
	return d, nil
}
// UpdateSchema runs GORM's AutoMigrate for the Note model, creating or
// updating its backing table. Any migration error is returned unchanged, and
// the "updated" debug line is emitted only on success.
func (d *Database) UpdateSchema() error {
	d.logger.Debug("updating database schema")
	err := d.db.AutoMigrate(&Note{})
	if err == nil {
		d.logger.Debug("database schema updated")
	}
	return err
}
// StartExpireThread deletes every note whose expires_at is not after the
// current time, then sleeps d.expirationInterval seconds, forever. It never
// returns; NewDatabase runs it in its own goroutine.
//
// A non-positive interval is clamped to one second. Otherwise time.Sleep
// would return immediately and the loop would issue DELETE statements
// back-to-back.
func (d *Database) StartExpireThread() {
	interval := d.expirationInterval
	if interval < 1 {
		d.logger.Warn("invalid expiration interval, using 1 second", slog.Int("interval", interval))
		interval = 1
	}

	// The interval never changes, so the pluralised unit can be computed once.
	wording := "second"
	if interval > 1 {
		wording += "s"
	}

	for {
		d.logger.Debug("deleting expired notes")
		trx := d.db.Where("expires_at <= ?", time.Now()).Delete(&Note{})
		if trx.Error != nil {
			d.logger.Error("could not delete expired notes", slog.Any("error", trx.Error))
		} else {
			d.logger.Debug("expired notes deleted", slog.Int64("count", trx.RowsAffected))
		}

		d.logger.Debug(fmt.Sprintf("waiting for %d %s before next expiration", interval, wording))
		time.Sleep(time.Duration(interval) * time.Second)
	}
}
// Get returns the note with the given id, or (nil, nil) when no such note is
// available. A note whose expires_at has already passed is treated as missing,
// even if the background purge has not deleted it yet.
//
// When the note is flagged DeleteAfterRead, it is deleted before being
// returned. The DELETE's affected-row count decides which reader wins. If a
// concurrent Get already removed the row, this call reports the note as
// missing rather than serving it a second time.
func (d *Database) Get(id string) (*Note, error) {
	var note Note
	trx := d.db.Where("id = ? AND expires_at > ?", id, time.Now()).Limit(1).Find(&note)
	if trx.Error != nil {
		d.logger.Warn("could not find note", slog.Any("error", trx.Error))
		return nil, trx.Error
	}
	if note.ID == "" {
		return nil, nil
	}

	if note.DeleteAfterRead {
		res := d.db.Where("id = ?", note.ID).Delete(&Note{})
		if res.Error != nil {
			d.logger.Error("could not delete note", slog.Any("error", res.Error))
			return nil, res.Error
		}
		if res.RowsAffected == 0 {
			// Another reader consumed this note between our SELECT and DELETE.
			return nil, nil
		}
	}
	return &note, nil
}
// Create validates the request, stores a new note and returns it with its
// generated snowflake ID.
//
// Parameters:
//   - expiration: a lifetime in seconds. 0 selects the configured default, and
//     the result must be one of the configured allowed expirations.
//   - language: matched case-insensitively against the configured (lowercased)
//     languages. "" selects the configured default.
//   - encryptionKey: when non-empty, the key is validated, content is
//     encrypted with it, and the note is marked encrypted regardless of the
//     encrypted argument. Otherwise encrypted is stored as given.
//
// Validation, encryption and database errors are returned; nothing is stored
// in those cases.
func (d *Database) Create(content []byte, encryptionKey string, encrypted bool, expiration int, deleteAfterRead bool, language string) (note *Note, err error) {
	if expiration == 0 {
		expiration = d.expiration
	}
	if !slices.Contains(d.expirations, expiration) {
		return nil, fmt.Errorf("invalid expiration: must be one of %s", joinChoices(d.expirations))
	}

	// d.languages and d.language are lowercased when the Database is built, so
	// normalise the request the same way before comparing.
	language = strings.ToLower(language)
	if language == "" {
		language = d.language
	}
	if !slices.Contains(d.languages, language) {
		return nil, fmt.Errorf("invalid language: must be one of %s", joinChoices(d.languages))
	}

	note = &Note{
		ID:              d.node.Generate().String(),
		Content:         content,
		ExpiresAt:       time.Now().Add(time.Duration(expiration) * time.Second),
		Encrypted:       encrypted,
		DeleteAfterRead: deleteAfterRead,
		Language:        language,
	}

	if encryptionKey != "" {
		if err = internal.ValidateEncryptionKey(encryptionKey); err != nil {
			return nil, err
		}
		note.Content, err = internal.Encrypt(note.Content, encryptionKey)
		if err != nil {
			return nil, err
		}
		note.Encrypted = true
	}

	if err = d.db.Create(note).Error; err != nil {
		d.logger.Warn("could not create note", slog.Any("error", err))
		return nil, err
	}
	return note, nil
}

// joinChoices renders values as a comma-separated list for error messages,
// e.g. []int{60, 3600} becomes "60, 3600" and an empty slice becomes "".
func joinChoices[T any](values []T) string {
	parts := make([]string, len(values))
	for i, v := range values {
		parts[i] = fmt.Sprint(v)
	}
	return strings.Join(parts, ", ")
}
// Delete removes the note with the given id. Matching no row is not an error.
// A database error is logged and returned unchanged.
//
// The statement runs on the plain (non-transaction) handle, so there is
// nothing to commit afterwards.
func (d *Database) Delete(id string) error {
	if err := d.db.Where("id = ?", id).Delete(&Note{}).Error; err != nil {
		d.logger.Error("could not delete note", slog.Any("error", err))
		return err
	}
	return nil
}