DB migration, small technicalities.

* DB migration rewritten:
  * Migration scripts are embedded as FS.
  * Migration versions are handled automatically.
* Use generics in utils.
This commit is contained in:
zegkljan 2022-09-22 22:05:41 +02:00
parent c1cdd4f904
commit c9377b04fc
4 changed files with 105 additions and 36 deletions

View File

@ -4,6 +4,7 @@ import (
"archive/zip"
"bytes"
"context"
"embed"
"encoding/json"
"fmt"
"io"
@ -12,8 +13,6 @@ import (
"os"
"time"
_ "embed"
"cernobor.cz/oko-server/errs"
"cernobor.cz/oko-server/models"
"crawshaw.io/sqlite"
@ -23,11 +22,8 @@ import (
"github.com/sirupsen/logrus"
)
//go:embed sql_schema/V1_init.sql
var sql_v1 string
//go:embed sql_schema/V2_proposals.sql
var sql_v2 string
//go:embed sql_schema/V*.sql
var sqlSchema embed.FS
type Server struct {
config ServerConfig
@ -240,7 +236,7 @@ func (s *Server) initDB(reinit bool) error {
defer s.requestCheckpoint()
if reinit {
s.log.Warn("REinitializing DB.")
s.log.Warn("Reinitializing DB.")
tables := []string{}
err := sqlitex.Exec(conn, "select name from sqlite_master where type = 'table'", func(stmt *sqlite.Stmt) error {
tables = append(tables, stmt.ColumnText(0))
@ -262,6 +258,16 @@ func (s *Server) initDB(reinit bool) error {
return fmt.Errorf("failed to reset user version: %w", err)
}
}
err := s.migrateDb(conn)
if err != nil {
return fmt.Errorf("failed to migrate db: %w", err)
}
return nil
}
func (s *Server) migrateDb(conn *sqlite.Conn) error {
var version int
err := sqlitex.Exec(conn, "PRAGMA user_version", func(stmt *sqlite.Stmt) error {
version = stmt.ColumnInt(0)
@ -270,21 +276,96 @@ func (s *Server) initDB(reinit bool) error {
if err != nil {
return fmt.Errorf("failed to get user version: %w", err)
}
s.log.Debugf("Current db version: %d", version)
if version <= 0 {
s.log.Debugf("Running db migration V1")
err = sqlitex.ExecScript(conn, sql_v1)
entries, err := sqlSchema.ReadDir("sql_schema")
if err != nil {
return fmt.Errorf("failed to run V1 init script")
return fmt.Errorf("failed to read sql_schema migrations: %w", err)
}
type migration struct {
file string
version int
name string
}
if version <= 1 {
s.log.Debugf("Running db migration V2")
err = sqlitex.ExecScript(conn, sql_v2)
pattern := regexp.MustCompile("^V([0-9]+)_(.*)[.][sS][qQ][lL]$")
migrations := []migration{}
for _, entry := range entries {
name := entry.Name()
if entry.IsDir() {
return fmt.Errorf("embedded sql migration '%s' is a directory", name)
}
matches := pattern.FindStringSubmatch(name)
if matches == nil {
return fmt.Errorf("embedded sql migration '%s' does not match the filename pattern", name)
}
if len(matches) != 3 {
return fmt.Errorf("embedded sql migration '%s' does not have the correct number of submatches", name)
}
version, err := strconv.Atoi(matches[1])
if err != nil {
return fmt.Errorf("failed to run V2 init script")
return fmt.Errorf("failed to parse version number of migration '%s': %w", name, err)
}
migName := matches[2]
file := path.Join("sql_schema", name)
migrations = append(migrations, migration{
file: file,
version: version,
name: migName,
})
}
sort.Slice(migrations, func(i, j int) bool {
return migrations[i].version < migrations[j].version
})
for _, migration := range migrations {
if version >= migration.version {
s.log.Debugf("Skipping migration version %d because current version %d is not smaller.", migration.version, version)
continue
}
migContent, err := sqlSchema.ReadFile(migration.file)
if err != nil {
return fmt.Errorf("failed to read embedded migration '%s': %w", migration.file, err)
}
err = func() (err error) {
rollback := sqlitex.Save(conn)
defer func() {
if err != nil {
s.log.Info("Rolling back last migration attempt.")
}
rollback(&err)
}()
s.log.Infof("Executing migration V%d - %s", migration.version, migration.name)
err = sqlitex.ExecScript(conn, string(migContent))
if err != nil {
return fmt.Errorf("failed to execute migration '%s': %w", migration.name, err)
}
err = sqlitex.Exec(conn, fmt.Sprintf("PRAGMA user_version = %d", migration.version), nil)
if err != nil {
return fmt.Errorf("failed to set user_version in db: %w", err)
}
err = sqlitex.Exec(conn, "PRAGMA user_version", func(stmt *sqlite.Stmt) error {
version = stmt.ColumnInt(0)
return nil
})
if err != nil {
return fmt.Errorf("failed to get user_version: %w", err)
}
s.log.Infof("Migrated db to version: %d", version)
return nil
}()
if err != nil {
return err
}
}
return nil
}
return nil
@ -300,7 +381,7 @@ func (s *Server) handshake(hc models.HandshakeChallenge) (models.UserID, error)
var id *int64
if hc.Exists {
err = sqlitex.Exec(conn, "select id from users where name = ?", func(stmt *sqlite.Stmt) error {
id = ptrInt64(stmt.ColumnInt64(0))
id = ptr(stmt.ColumnInt64(0))
return nil
}, hc.Name)
if sqlite.ErrCode(err) != sqlite.SQLITE_OK {
@ -314,7 +395,7 @@ func (s *Server) handshake(hc models.HandshakeChallenge) (models.UserID, error)
}
} else {
err = sqlitex.Exec(conn, "insert into users(name) values(?)", func(stmt *sqlite.Stmt) error {
id = ptrInt64(stmt.ColumnInt64(0))
id = ptr(stmt.ColumnInt64(0))
return nil
}, hc.Name)
if sqlite.ErrCode(err) == sqlite.SQLITE_CONSTRAINT_UNIQUE {
@ -323,7 +404,7 @@ func (s *Server) handshake(hc models.HandshakeChallenge) (models.UserID, error)
if sqlite.ErrCode(err) != sqlite.SQLITE_OK {
return 0, err
}
id = ptrInt64(conn.LastInsertRowID())
id = ptr(conn.LastInsertRowID())
s.requestCheckpoint()
}
return *id, nil
@ -925,7 +1006,7 @@ func (s *Server) getPhoto(featureID models.FeatureID, photoID models.FeaturePhot
if found {
return fmt.Errorf("multiple photos returned for feature id %d, photo id %d", featureID, photoID)
}
contentType = ptrString(stmt.ColumnText(0))
contentType = ptr(stmt.ColumnText(0))
data = make([]byte, stmt.ColumnLen(1))
stmt.ColumnBytes(1, data)
found = true

View File

@ -21,5 +21,3 @@ CREATE TABLE feature_photos (
thumbnail_contents blob NOT NULL,
contents blob NOT NULL
);
PRAGMA user_version = 1;

View File

@ -3,5 +3,3 @@ CREATE TABLE proposals (
description text NOT NULL,
how text NOT NULL
);
PRAGMA user_version = 2;

View File

@ -13,15 +13,7 @@ import (
"cernobor.cz/oko-server/models"
)
func ptrInt(x int) *int {
return &x
}
func ptrInt64(x int64) *int64 {
return &x
}
func ptrString(x string) *string {
func ptr[T any](x T) *T {
return &x
}