wargh/db/db.go
2024-11-17 22:04:42 +02:00

138 lines
2.5 KiB
Go

package db
import (
"database/sql"
"log"
"os"
"path"
"strconv"
"strings"
_ "github.com/mattn/go-sqlite3"
"github.com/umpc/go-sortedmap"
"github.com/umpc/go-sortedmap/asc"
)
type DBConfig struct {
DBPath string
MigrationsPath string
}
var config *DBConfig = nil
func Init(_config *DBConfig) {
config = _config
}
func checkConfig() {
if config == nil {
log.Fatal("DBConfig not initialized! Call db.Init!")
}
}
func open(dbPath string) (*sql.DB, error) {
db, err := sql.Open("sqlite3", dbPath)
if err != nil {
return nil, err
}
/* if database is empty. we must create
'_migration` table and apply all releveant migrations
*/
_, err = db.Query("SELECT `_seq` FROM `_migration` LIMIT 1;")
if err != nil {
// create table
_, err = db.Exec(`
CREATE TABLE _migration (
_seq INTEGER PRIMARY KEY UNIQUE
);`)
// can't create migration table - fatal
if err != nil {
return nil, err
}
}
return db, nil
}
func applyMigrations(db *sql.DB) {
entries, err := os.ReadDir(config.MigrationsPath)
if err != nil {
log.Fatal(err)
}
// load migrations from directory
migrations := sortedmap.New(len(entries), asc.Int)
for _, entry := range entries {
if entry.IsDir() {
continue
}
name := entry.Name()
split := strings.Split(name, "_")
if len(split) < 2 {
log.Printf("%s does not follow 000_name.sql convention\n", name)
continue
}
seq, err := strconv.Atoi(split[0])
if err != nil {
log.Fatal(err)
}
migrations.Insert(name, seq)
}
if migrations.Len() == 0 {
log.Fatal("no migrations!")
}
// detect last database migration
var seq int = -1
row := db.QueryRow("SELECT `_seq` FROM `_migration` ORDER BY `_seq` DESC LIMIT 1;")
if row.Err() == nil {
row.Scan(&seq)
}
// apply migrations starting from seq
iter, err := migrations.BoundedIterCh(false, seq+1, nil)
if err != nil {
// no migrations to apply
return
}
defer iter.Close()
for m := range iter.Records() {
path := path.Join(config.MigrationsPath, m.Key.(string))
bytes, err := os.ReadFile(path)
if err != nil {
log.Fatal(err)
}
sql := string(bytes[:])
_, err = db.Exec(sql)
if err != nil {
log.Fatal(err)
}
// successfuly applied migration, so bump seq and insert in table
seq = m.Val.(int)
_, err = db.Exec("INSERT INTO `_migration` (_seq) VALUES (?);", seq)
if err != nil {
log.Fatal(err)
}
log.Printf("Applied migration %d\n", seq)
}
}
func Open() (*sql.DB, error) {
checkConfig()
db, err := open(config.DBPath)
if err != nil {
return nil, err
}
applyMigrations(db)
return db, nil
}