benchito/config.go
Julien Riou d1e3da5189
feat: Initial release
Signed-off-by: Julien Riou <julien@riou.xyz>
2022-10-09 09:26:56 +02:00

110 lines
2.6 KiB
Go

package main
import (
"fmt"
"io/ioutil"
"path/filepath"
"strings"
"time"
"github.com/go-sql-driver/mysql"
"gopkg.in/yaml.v2"
)
// Config stores all settings for a run, whatever their origin (command line,
// YAML file loaded by Read, etc). The yaml tags give the key names accepted
// in the configuration file.
type Config struct {
	// Driver selects the DSN format ParseDSN builds: "postgres" or "mysql".
	Driver string `yaml:"driver"`
	// NOTE(review): not used in this file; presumably the number of
	// concurrent connections to open — confirm with the caller.
	Connections int `yaml:"connections"`
	// NOTE(review): not used in this file; presumably the SQL statement to
	// run — confirm with the caller.
	Query string `yaml:"query"`
	// NOTE(review): not used in this file; presumably how long the run
	// lasts — confirm with the caller.
	Duration time.Duration `yaml:"duration"`
	// NOTE(review): not used in this file; presumably whether to reconnect
	// between queries — confirm with the caller.
	Reconnect bool `yaml:"reconnect"`
	// DSN is an explicit connection string. When it is non-empty, ParseDSN
	// keeps it as-is and ignores the individual connection fields below.
	DSN string `yaml:"dsn"`
	// Host, Port, User, Password and Database are the individual connection
	// settings ParseDSN uses to build a DSN when DSN is empty.
	Host     string `yaml:"host"`
	Port     int    `yaml:"port"`
	User     string `yaml:"user"`
	Password string `yaml:"password"`
	Database string `yaml:"database"`
	// TLS is passed as sslmode for postgres and as the driver's TLSConfig
	// for mysql.
	TLS string `yaml:"tls"`
	// ConnectTimeout is the connection timeout; for postgres it is passed as
	// connect_timeout, which libpq interprets as a number of seconds.
	ConnectTimeout int `yaml:"connect_timeout"`
}
// NewConfig returns a pointer to a zero-valued Config, ready to be populated
// (for example by Read).
func NewConfig() *Config {
	cfg := new(Config)
	return cfg
}
// Read loads the YAML configuration file at the given path and decodes it
// into c.
//
// The path is first made absolute so that error messages name the exact file
// involved. Errors from resolving the path and from YAML decoding are wrapped
// with that path (yaml errors on their own only report a line number).
func (c *Config) Read(file string) error {
	path, err := filepath.Abs(file)
	if err != nil {
		return fmt.Errorf("resolving config path %q: %w", file, err)
	}
	data, err := ioutil.ReadFile(path)
	if err != nil {
		// The returned *PathError already names the file; no extra context needed.
		return err
	}
	// Decode directly into the receiver rather than through a **Config.
	if err := yaml.Unmarshal(data, c); err != nil {
		return fmt.Errorf("parsing config file %q: %w", path, err)
	}
	return nil
}
// ParseDSN fills in c.DSN from the individual connection fields when no DSN
// was supplied. The format is chosen by c.Driver: "postgres" or "mysql".
// For any other driver, or when DSN is already set, c is left unchanged.
func (c *Config) ParseDSN() {
	if c.DSN != "" {
		// An explicit DSN always takes precedence over the individual fields.
		return
	}
	switch c.Driver {
	case "mysql":
		c.DSN = c.parseMysqlDSN()
	case "postgres":
		c.DSN = c.parsePostgresDSN()
	}
}
// parsePostgresDSN builds a libpq keyword/value connection string such as
// "host=db port=5432 user=bench" from the individual connection fields.
// Fields left at their zero value are omitted. c.TLS is emitted as sslmode
// and the package-level AppName as application_name.
//
// Every value goes through quotePostgresDSNValue. Without it, a password (or
// any other value) containing a space, quote or backslash would break the
// string or inject additional keywords.
func (c *Config) parsePostgresDSN() string {
	var parameters []string
	add := func(key, value string) {
		parameters = append(parameters, key+"="+quotePostgresDSNValue(value))
	}
	if c.Host != "" {
		add("host", c.Host)
	}
	if c.Port != 0 {
		add("port", fmt.Sprint(c.Port))
	}
	if c.User != "" {
		add("user", c.User)
	}
	if c.Password != "" {
		add("password", c.Password)
	}
	if c.Database != "" {
		// "dbname" is the documented libpq keyword for the database name.
		add("dbname", c.Database)
	}
	if c.ConnectTimeout != 0 {
		add("connect_timeout", fmt.Sprint(c.ConnectTimeout))
	}
	if AppName != "" {
		add("application_name", AppName)
	}
	if c.TLS != "" {
		add("sslmode", c.TLS)
	}
	return strings.Join(parameters, " ")
}

// quotePostgresDSNValue escapes v for use in a libpq keyword/value string.
// Plain printable-ASCII values without quotes or backslashes are returned
// unchanged. Anything else (empty, whitespace, control or non-ASCII
// characters, ' or \) is wrapped in single quotes, with ' and \ escaped by a
// backslash, as the libpq connection-string rules require.
func quotePostgresDSNValue(v string) string {
	needsQuoting := v == "" || strings.IndexFunc(v, func(r rune) bool {
		return r <= ' ' || r > '~' || r == '\'' || r == '\\'
	}) >= 0
	if !needsQuoting {
		return v
	}
	return "'" + strings.NewReplacer(`\`, `\\`, `'`, `\'`).Replace(v) + "'"
}
// parseMysqlDSN builds a go-sql-driver/mysql DSN from the individual
// connection fields.
//
// FormatDSN only writes the address when a network is set, and NewConfig
// leaves Net empty. The network is therefore set explicitly whenever a host
// or port is configured; otherwise the configured host would be silently
// dropped and the driver would use its default address:
//   - an absolute path in Host is treated as a Unix domain socket;
//   - otherwise TCP is used, and IPv6 literals are bracketed before a port is
//     appended (a missing port is filled in by the driver).
//
// ConnectTimeout, in seconds as for PostgreSQL's connect_timeout, becomes the
// driver's dial timeout.
func (c *Config) parseMysqlDSN() string {
	cfg := mysql.NewConfig()
	switch {
	case strings.HasPrefix(c.Host, "/"):
		cfg.Net, cfg.Addr = "unix", c.Host
	case c.Port != 0:
		host := c.Host
		if strings.Contains(host, ":") {
			// Bracket IPv6 literals, as net.JoinHostPort would.
			host = "[" + host + "]"
		}
		cfg.Net, cfg.Addr = "tcp", fmt.Sprintf("%s:%d", host, c.Port)
	case c.Host != "":
		cfg.Net, cfg.Addr = "tcp", c.Host
	}
	cfg.User = c.User
	cfg.Passwd = c.Password
	cfg.DBName = c.Database
	cfg.TLSConfig = c.TLS
	if c.ConnectTimeout > 0 {
		cfg.Timeout = time.Duration(c.ConnectTimeout) * time.Second
	}
	return cfg.FormatDSN()
}