Expose dialect Querier by mfridman · Pull Request #939 · pressly/goose · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Expose dialect Querier #939

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
package dialectquery
package dialect

import "fmt"
import (
"fmt"
)

type Clickhouse struct{}
// NewClickhouse returns a new [Querier] that builds SQL for the ClickHouse
// dialect.
func NewClickhouse() Querier {
	return new(clickhouse)
}

// clickhouse is the ClickHouse implementation of [Querier]. It has no fields;
// all behavior lives in its methods.
type clickhouse struct{}

var _ Querier = (*Clickhouse)(nil)
var _ Querier = (*clickhouse)(nil)

func (c *Clickhouse) CreateTable(tableName string) string {
func (c *clickhouse) CreateTable(tableName string) string {
q := `CREATE TABLE IF NOT EXISTS %s (
version_id Int64,
is_applied UInt8,
Expand All @@ -18,27 +25,27 @@ func (c *Clickhouse) CreateTable(tableName string) string {
return fmt.Sprintf(q, tableName)
}

func (c *Clickhouse) InsertVersion(tableName string) string {
// InsertVersion returns the SQL that inserts one (version_id, is_applied) row
// into tableName, binding both values through $1/$2 positional placeholders.
// tableName is interpolated verbatim, so it must be a trusted identifier.
func (c *clickhouse) InsertVersion(tableName string) string {
	const query = `INSERT INTO %s (version_id, is_applied) VALUES ($1, $2)`
	return fmt.Sprintf(query, tableName)
}

func (c *Clickhouse) DeleteVersion(tableName string) string {
// DeleteVersion returns the SQL that removes the rows for one version_id ($1)
// from tableName. It is issued as ALTER TABLE ... DELETE with
// SETTINGS mutations_sync = 2.
// NOTE(review): the setting presumably makes the call wait for the mutation to
// finish; confirm against the ClickHouse documentation.
func (c *clickhouse) DeleteVersion(tableName string) string {
	const query = `ALTER TABLE %s DELETE WHERE version_id = $1 SETTINGS mutations_sync = 2`
	return fmt.Sprintf(query, tableName)
}

func (c *Clickhouse) GetMigrationByVersion(tableName string) string {
// GetMigrationByVersion returns the SQL that reads tstamp and is_applied for
// the newest row (highest tstamp) whose version_id equals $1. LIMIT 1 means at
// most one row comes back.
func (c *clickhouse) GetMigrationByVersion(tableName string) string {
	const query = `SELECT tstamp, is_applied FROM %s WHERE version_id = $1 ORDER BY tstamp DESC LIMIT 1`
	return fmt.Sprintf(query, tableName)
}

func (c *Clickhouse) ListMigrations(tableName string) string {
// ListMigrations returns the SQL that lists every (version_id, is_applied) row
// in tableName. Unlike the other dialects, rows are ordered by version_id
// descending rather than by an id column.
func (c *clickhouse) ListMigrations(tableName string) string {
	const query = `SELECT version_id, is_applied FROM %s ORDER BY version_id DESC`
	return fmt.Sprintf(query, tableName)
}

func (c *Clickhouse) GetLatestVersion(tableName string) string {
// GetLatestVersion returns the SQL that selects the largest version_id stored
// in tableName.
func (c *clickhouse) GetLatestVersion(tableName string) string {
	const query = `SELECT max(version_id) FROM %s`
	return fmt.Sprintf(query, tableName)
}
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
package dialectquery
package dialect

import "fmt"
import (
"fmt"
)

type Mysql struct{}
// NewMysql returns a new [QuerierExtender] for the MySQL dialect. Besides the
// core [Querier] methods, the returned value also implements TableExists.
func NewMysql() QuerierExtender {
	return &mysql{}
}

// mysql is the MySQL implementation of [QuerierExtender]. It has no fields.
type mysql struct{}

var _ Querier = (*Mysql)(nil)
var _ QuerierExtender = (*mysql)(nil)

func (m *Mysql) CreateTable(tableName string) string {
func (m *mysql) CreateTable(tableName string) string {
q := `CREATE TABLE %s (
id bigint(20) unsigned NOT NULL AUTO_INCREMENT,
version_id bigint NOT NULL,
Expand All @@ -17,32 +24,32 @@ func (m *Mysql) CreateTable(tableName string) string {
return fmt.Sprintf(q, tableName)
}

func (m *Mysql) InsertVersion(tableName string) string {
// InsertVersion returns the SQL that inserts one (version_id, is_applied) row
// into tableName; both values are bound through MySQL-style ? placeholders.
func (m *mysql) InsertVersion(tableName string) string {
	return fmt.Sprintf(`INSERT INTO %s (version_id, is_applied) VALUES (?, ?)`, tableName)
}

func (m *Mysql) DeleteVersion(tableName string) string {
// DeleteVersion returns the SQL that deletes every row in tableName whose
// version_id matches the single ? placeholder.
func (m *mysql) DeleteVersion(tableName string) string {
	return fmt.Sprintf(`DELETE FROM %s WHERE version_id=?`, tableName)
}

func (m *Mysql) GetMigrationByVersion(tableName string) string {
// GetMigrationByVersion returns the SQL that fetches tstamp and is_applied for
// the most recent row (by tstamp) with the given version_id; LIMIT 1 caps the
// result at one row.
func (m *mysql) GetMigrationByVersion(tableName string) string {
	return fmt.Sprintf(`SELECT tstamp, is_applied FROM %s WHERE version_id=? ORDER BY tstamp DESC LIMIT 1`, tableName)
}

func (m *Mysql) ListMigrations(tableName string) string {
// ListMigrations returns the SQL that lists all (version_id, is_applied) rows in
// tableName, ordered by the id column descending (most recently inserted first).
func (m *mysql) ListMigrations(tableName string) string {
	return fmt.Sprintf(`SELECT version_id, is_applied from %s ORDER BY id DESC`, tableName)
}

func (m *Mysql) GetLatestVersion(tableName string) string {
// GetLatestVersion returns the SQL that selects the maximum version_id recorded
// in tableName.
func (m *mysql) GetLatestVersion(tableName string) string {
	return fmt.Sprintf(`SELECT MAX(version_id) FROM %s`, tableName)
}

func (m *Mysql) TableExists(tableName string) string {
func (m *mysql) TableExists(tableName string) string {
schemaName, tableName := parseTableIdentifier(tableName)
if schemaName != "" {
q := `SELECT EXISTS ( SELECT 1 FROM information_schema.tables WHERE table_schema = '%s' AND table_name = '%s' )`
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
package dialectquery
package dialect

import (
"fmt"
"strings"
)

type Postgres struct{}
// NewPostgres returns a new [QuerierExtender] for the PostgreSQL dialect.
// Besides the core [Querier] methods, the returned value also implements
// TableExists.
func NewPostgres() QuerierExtender {
	return &postgres{}
}

// postgres is the PostgreSQL implementation of [QuerierExtender]. It has no
// fields.
type postgres struct{}

var _ Querier = (*Postgres)(nil)
var _ QuerierExtender = (*postgres)(nil)

func (p *Postgres) CreateTable(tableName string) string {
func (p *postgres) CreateTable(tableName string) string {
q := `CREATE TABLE %s (
id integer PRIMARY KEY GENERATED BY DEFAULT AS IDENTITY,
version_id bigint NOT NULL,
Expand All @@ -18,32 +24,32 @@ func (p *Postgres) CreateTable(tableName string) string {
return fmt.Sprintf(q, tableName)
}

func (p *Postgres) InsertVersion(tableName string) string {
// InsertVersion returns the SQL that inserts one (version_id, is_applied) row
// into tableName using $1/$2 positional parameters.
func (p *postgres) InsertVersion(tableName string) string {
	const stmt = `INSERT INTO %s (version_id, is_applied) VALUES ($1, $2)`
	return fmt.Sprintf(stmt, tableName)
}

func (p *Postgres) DeleteVersion(tableName string) string {
// DeleteVersion returns the SQL that deletes the rows of tableName whose
// version_id equals $1.
func (p *postgres) DeleteVersion(tableName string) string {
	const stmt = `DELETE FROM %s WHERE version_id=$1`
	return fmt.Sprintf(stmt, tableName)
}

func (p *Postgres) GetMigrationByVersion(tableName string) string {
// GetMigrationByVersion returns the SQL that reads tstamp and is_applied from
// the newest (highest tstamp) row matching version_id $1, limited to one row.
func (p *postgres) GetMigrationByVersion(tableName string) string {
	const stmt = `SELECT tstamp, is_applied FROM %s WHERE version_id=$1 ORDER BY tstamp DESC LIMIT 1`
	return fmt.Sprintf(stmt, tableName)
}

func (p *Postgres) ListMigrations(tableName string) string {
// ListMigrations returns the SQL that lists every (version_id, is_applied) row
// of tableName, ordered by the identity id column descending (newest first).
func (p *postgres) ListMigrations(tableName string) string {
	const stmt = `SELECT version_id, is_applied from %s ORDER BY id DESC`
	return fmt.Sprintf(stmt, tableName)
}

func (p *Postgres) GetLatestVersion(tableName string) string {
// GetLatestVersion returns the SQL that selects the highest version_id present
// in tableName.
func (p *postgres) GetLatestVersion(tableName string) string {
	const stmt = `SELECT max(version_id) FROM %s`
	return fmt.Sprintf(stmt, tableName)
}

func (p *Postgres) TableExists(tableName string) string {
func (p *postgres) TableExists(tableName string) string {
schemaName, tableName := parseTableIdentifier(tableName)
if schemaName != "" {
q := `SELECT EXISTS ( SELECT 1 FROM pg_tables WHERE schemaname = '%s' AND tablename = '%s' )`
Expand All @@ -52,3 +58,11 @@ func (p *Postgres) TableExists(tableName string) string {
q := `SELECT EXISTS ( SELECT 1 FROM pg_tables WHERE (current_schema() IS NULL OR schemaname = current_schema()) AND tablename = '%s' )`
return fmt.Sprintf(q, tableName)
}

// parseTableIdentifier splits a possibly schema-qualified table name at its
// first '.'.
//
// "schema.table" yields ("schema", "table"). A name with no dot yields ("", name).
// Only the first dot is significant, so "a.b.c" yields ("a", "b.c").
func parseTableIdentifier(name string) (schema, table string) {
	dot := strings.IndexByte(name, '.')
	if dot < 0 {
		return "", name
	}
	return name[:dot], name[dot+1:]
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
package dialectquery

import "strings"
package dialect

// Querier is the interface that wraps the basic methods to create a dialect specific query.
type Querier interface {
Expand All @@ -27,33 +25,3 @@ type Querier interface {
// table. Returns a nullable int64 value.
GetLatestVersion(tableName string) string
}

// Compile-time assertion that *QueryController satisfies Querier, which it does
// through its embedded Querier field.
var _ Querier = (*QueryController)(nil)

// QueryController embeds a Querier, forwarding its methods unchanged, and adds
// optional methods that are delegated only when the wrapped value supports them.
type QueryController struct{ Querier }

// NewQueryController returns a new QueryController that wraps the given Querier.
func NewQueryController(querier Querier) *QueryController {
	qc := new(QueryController)
	qc.Querier = querier
	return qc
}

// Optional methods

// TableExists returns the SQL that checks whether the version table exists.
//
// The call is forwarded to the wrapped Querier only when that value has a
// TableExists(string) string method. Otherwise, the empty string is returned so
// callers can fall back to another strategy. The generated query is expected
// to yield a boolean.
func (c *QueryController) TableExists(tableName string) string {
	type tableExister interface {
		TableExists(string) string
	}
	impl, ok := c.Querier.(tableExister)
	if !ok {
		return ""
	}
	return impl.TableExists(tableName)
}

// parseTableIdentifier splits name into an optional schema and a table name at
// the first '.'. When there is no dot, schema is "" and table is name unchanged.
func parseTableIdentifier(name string) (schema, table string) {
	parts := strings.SplitN(name, ".", 2)
	if len(parts) == 2 {
		return parts[0], parts[1]
	}
	return "", name
}
28 changes: 28 additions & 0 deletions database/dialect/querier_extended.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package dialect

// QuerierExtender is an extension of the Querier interface that provides optional optimizations and
// database-specific features. While not required by the core package, implementing these methods
// can improve performance and functionality for specific databases.
//
// IMPORTANT: This interface may be expanded in future versions. Implementors MUST be prepared to
// update their implementations when new methods are added, either by implementing the new
// functionality or by returning an empty string.
//
// Example usage to verify implementation:
//
//	var _ QuerierExtender = (*CustomQuerierExtended)(nil)
//
// In short, the interface is exported so that implementors can add a compile-time check, like the
// one above, that they implement it correctly.
type QuerierExtender interface {
	Querier

	// TableExists returns the SQL query string to check if a table exists in the database.
	// Implementing this method allows the system to optimize table existence checks by using
	// database-specific system catalogs (e.g., pg_tables for PostgreSQL, sqlite_master for SQLite)
	// instead of generic SQL queries.
	//
	// Return an empty string if the database does not provide an efficient way to check table
	// existence.
	TableExists(tableName string) string
}
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
package dialectquery
package dialect

import "fmt"
import (
"fmt"
)

type Redshift struct{}
// NewRedshift returns a new [Querier] for the Redshift dialect.
func NewRedshift() Querier {
	return &redshift{}
}

// redshift is the Redshift implementation of [Querier]. It has no fields.
type redshift struct{}

var _ Querier = (*Redshift)(nil)
var _ Querier = (*redshift)(nil)

func (r *Redshift) CreateTable(tableName string) string {
func (r *redshift) CreateTable(tableName string) string {
q := `CREATE TABLE %s (
id integer NOT NULL identity(1, 1),
version_id bigint NOT NULL,
Expand All @@ -17,27 +24,27 @@ func (r *Redshift) CreateTable(tableName string) string {
return fmt.Sprintf(q, tableName)
}

func (r *Redshift) InsertVersion(tableName string) string {
// InsertVersion returns the SQL that inserts a (version_id, is_applied) row into
// tableName, with values bound via $1 and $2.
func (r *redshift) InsertVersion(tableName string) string {
	return fmt.Sprintf(`INSERT INTO %s (version_id, is_applied) VALUES ($1, $2)`, tableName)
}

func (r *Redshift) DeleteVersion(tableName string) string {
// DeleteVersion returns the SQL that deletes rows of tableName whose version_id
// equals $1.
func (r *redshift) DeleteVersion(tableName string) string {
	return fmt.Sprintf(`DELETE FROM %s WHERE version_id=$1`, tableName)
}

func (r *Redshift) GetMigrationByVersion(tableName string) string {
// GetMigrationByVersion returns the SQL that selects tstamp and is_applied for
// the latest (by tstamp) row with version_id $1, returning at most one row.
func (r *redshift) GetMigrationByVersion(tableName string) string {
	return fmt.Sprintf(`SELECT tstamp, is_applied FROM %s WHERE version_id=$1 ORDER BY tstamp DESC LIMIT 1`, tableName)
}

func (r *Redshift) ListMigrations(tableName string) string {
// ListMigrations returns the SQL that lists all (version_id, is_applied) rows of
// tableName, ordered by the identity id column descending.
func (r *redshift) ListMigrations(tableName string) string {
	return fmt.Sprintf(`SELECT version_id, is_applied from %s ORDER BY id DESC`, tableName)
}

func (r *Redshift) GetLatestVersion(tableName string) string {
// GetLatestVersion returns the SQL that selects the maximum version_id in
// tableName.
func (r *redshift) GetLatestVersion(tableName string) string {
	return fmt.Sprintf(`SELECT max(version_id) FROM %s`, tableName)
}
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
package dialectquery
package dialect

import "fmt"
import (
"fmt"
)

type Sqlite3 struct{}
// NewSqlite3 returns a new [Querier] that builds SQL for the SQLite3 dialect.
func NewSqlite3() Querier {
	return new(sqlite3)
}

// sqlite3 is the SQLite3 implementation of [Querier]. It has no fields.
type sqlite3 struct{}

var _ Querier = (*Sqlite3)(nil)
var _ Querier = (*sqlite3)(nil)

func (s *Sqlite3) CreateTable(tableName string) string {
func (s *sqlite3) CreateTable(tableName string) string {
q := `CREATE TABLE %s (
id INTEGER PRIMARY KEY AUTOINCREMENT,
version_id INTEGER NOT NULL,
Expand All @@ -16,27 +23,27 @@ func (s *Sqlite3) CreateTable(tableName string) string {
return fmt.Sprintf(q, tableName)
}

func (s *Sqlite3) InsertVersion(tableName string) string {
// InsertVersion returns the SQL that inserts one (version_id, is_applied) row
// into tableName using ? placeholders.
func (s *sqlite3) InsertVersion(tableName string) string {
	const sqlTmpl = `INSERT INTO %s (version_id, is_applied) VALUES (?, ?)`
	return fmt.Sprintf(sqlTmpl, tableName)
}

func (s *Sqlite3) DeleteVersion(tableName string) string {
// DeleteVersion returns the SQL that deletes the rows of tableName matching the
// version_id bound to the single ? placeholder.
func (s *sqlite3) DeleteVersion(tableName string) string {
	const sqlTmpl = `DELETE FROM %s WHERE version_id=?`
	return fmt.Sprintf(sqlTmpl, tableName)
}

func (s *Sqlite3) GetMigrationByVersion(tableName string) string {
// GetMigrationByVersion returns the SQL that reads tstamp and is_applied for the
// newest row (by tstamp) with the given version_id, limited to one row.
func (s *sqlite3) GetMigrationByVersion(tableName string) string {
	const sqlTmpl = `SELECT tstamp, is_applied FROM %s WHERE version_id=? ORDER BY tstamp DESC LIMIT 1`
	return fmt.Sprintf(sqlTmpl, tableName)
}

func (s *Sqlite3) ListMigrations(tableName string) string {
// ListMigrations returns the SQL that lists every (version_id, is_applied) row
// of tableName, ordered by the autoincrement id column descending.
func (s *sqlite3) ListMigrations(tableName string) string {
	const sqlTmpl = `SELECT version_id, is_applied from %s ORDER BY id DESC`
	return fmt.Sprintf(sqlTmpl, tableName)
}

func (s *Sqlite3) GetLatestVersion(tableName string) string {
// GetLatestVersion returns the SQL that selects the largest version_id stored in
// tableName.
func (s *sqlite3) GetLatestVersion(tableName string) string {
	const sqlTmpl = `SELECT MAX(version_id) FROM %s`
	return fmt.Sprintf(sqlTmpl, tableName)
}
Loading
Loading
0