2016-12-28 20:03:40 -05:00
package testfixtures
import (
"database/sql"
"fmt"
2018-10-03 03:20:02 +08:00
"strings"
2016-12-28 20:03:40 -05:00
)
2020-06-17 15:07:58 -04:00
type postgreSQL struct {
2016-12-28 20:03:40 -05:00
baseHelper
2020-06-17 15:07:58 -04:00
useAlterConstraint bool
2020-08-13 21:54:46 -04:00
useDropConstraint bool
2020-06-17 15:07:58 -04:00
skipResetSequences bool
resetSequencesTo int64
2016-12-28 20:03:40 -05:00
tables [ ] string
sequences [ ] string
nonDeferrableConstraints [ ] pgConstraint
2020-08-13 21:54:46 -04:00
constraints [ ] pgConstraint
2018-10-03 03:20:02 +08:00
tablesChecksum map [ string ] string
2016-12-28 20:03:40 -05:00
}
// pgConstraint names a table constraint and, when known, its SQL definition.
type pgConstraint struct {
	// tableName is the table owning the constraint, as returned by the
	// catalog query that produced this value.
	tableName string
	// constraintName is the constraint's name.
	constraintName string
	// definition is the pg_get_constraintdef text; only getConstraints
	// fills it in.
	definition string
}
2020-06-17 15:07:58 -04:00
func ( h * postgreSQL ) init ( db * sql . DB ) error {
2016-12-28 20:03:40 -05:00
var err error
2018-10-03 03:20:02 +08:00
h . tables , err = h . tableNames ( db )
2016-12-28 20:03:40 -05:00
if err != nil {
return err
}
h . sequences , err = h . getSequences ( db )
if err != nil {
return err
}
h . nonDeferrableConstraints , err = h . getNonDeferrableConstraints ( db )
if err != nil {
return err
}
2020-08-13 21:54:46 -04:00
h . constraints , err = h . getConstraints ( db )
if err != nil {
return err
}
2016-12-28 20:03:40 -05:00
return nil
}
2020-06-17 15:07:58 -04:00
func ( * postgreSQL ) paramType ( ) int {
2016-12-28 20:03:40 -05:00
return paramTypeDollar
}
2020-06-17 15:07:58 -04:00
func ( * postgreSQL ) databaseName ( q queryable ) ( string , error ) {
2018-10-03 03:20:02 +08:00
var dbName string
err := q . QueryRow ( "SELECT current_database()" ) . Scan ( & dbName )
return dbName , err
2016-12-28 20:03:40 -05:00
}
2020-06-17 15:07:58 -04:00
func ( h * postgreSQL ) tableNames ( q queryable ) ( [ ] string , error ) {
2016-12-28 20:03:40 -05:00
var tables [ ] string
2021-04-23 02:08:53 +02:00
const sql = `
2018-10-03 03:20:02 +08:00
SELECT pg_namespace . nspname || '.' || pg_class . relname
FROM pg_class
INNER JOIN pg_namespace ON pg_namespace . oid = pg_class . relnamespace
WHERE pg_class . relkind = 'r'
2020-08-13 21:54:46 -04:00
AND pg_namespace . nspname NOT IN ( ' pg_catalog ' , ' information_schema ' , ' crdb_internal ' )
2020-06-17 15:07:58 -04:00
AND pg_namespace . nspname NOT LIKE ' pg_toast % '
AND pg_namespace . nspname NOT LIKE ' \ _timescaledb % ' ;
2018-10-03 03:20:02 +08:00
`
rows , err := q . Query ( sql )
2016-12-28 20:03:40 -05:00
if err != nil {
return nil , err
}
defer rows . Close ( )
2018-10-03 03:20:02 +08:00
2016-12-28 20:03:40 -05:00
for rows . Next ( ) {
var table string
2018-10-03 03:20:02 +08:00
if err = rows . Scan ( & table ) ; err != nil {
return nil , err
}
2016-12-28 20:03:40 -05:00
tables = append ( tables , table )
}
2018-10-03 03:20:02 +08:00
if err = rows . Err ( ) ; err != nil {
return nil , err
}
2016-12-28 20:03:40 -05:00
return tables , nil
}
2020-06-17 15:07:58 -04:00
func ( h * postgreSQL ) getSequences ( q queryable ) ( [ ] string , error ) {
2018-10-03 03:20:02 +08:00
const sql = `
SELECT pg_namespace . nspname || '.' || pg_class . relname AS sequence_name
FROM pg_class
INNER JOIN pg_namespace ON pg_namespace . oid = pg_class . relnamespace
WHERE pg_class . relkind = 'S'
2020-06-17 15:07:58 -04:00
AND pg_namespace . nspname NOT LIKE ' \ _timescaledb % '
2018-10-03 03:20:02 +08:00
`
2016-12-28 20:03:40 -05:00
2018-10-03 03:20:02 +08:00
rows , err := q . Query ( sql )
2016-12-28 20:03:40 -05:00
if err != nil {
return nil , err
}
defer rows . Close ( )
2018-10-03 03:20:02 +08:00
var sequences [ ] string
2016-12-28 20:03:40 -05:00
for rows . Next ( ) {
var sequence string
if err = rows . Scan ( & sequence ) ; err != nil {
return nil , err
}
sequences = append ( sequences , sequence )
}
2018-10-03 03:20:02 +08:00
if err = rows . Err ( ) ; err != nil {
return nil , err
}
2016-12-28 20:03:40 -05:00
return sequences , nil
}
2020-06-17 15:07:58 -04:00
func ( * postgreSQL ) getNonDeferrableConstraints ( q queryable ) ( [ ] pgConstraint , error ) {
2016-12-28 20:03:40 -05:00
var constraints [ ] pgConstraint
2021-04-23 02:08:53 +02:00
const sql = `
2018-10-03 03:20:02 +08:00
SELECT table_schema || '.' || table_name , constraint_name
FROM information_schema . table_constraints
WHERE constraint_type = ' FOREIGN KEY '
AND is_deferrable = ' NO '
2020-08-13 21:54:46 -04:00
AND table_schema < > ' crdb_internal '
2020-06-17 15:07:58 -04:00
AND table_schema NOT LIKE ' \ _timescaledb % '
2018-10-03 03:20:02 +08:00
`
rows , err := q . Query ( sql )
2016-12-28 20:03:40 -05:00
if err != nil {
return nil , err
}
defer rows . Close ( )
2020-08-13 21:54:46 -04:00
2016-12-28 20:03:40 -05:00
for rows . Next ( ) {
var constraint pgConstraint
2018-10-03 03:20:02 +08:00
if err = rows . Scan ( & constraint . tableName , & constraint . constraintName ) ; err != nil {
2016-12-28 20:03:40 -05:00
return nil , err
}
constraints = append ( constraints , constraint )
}
2018-10-03 03:20:02 +08:00
if err = rows . Err ( ) ; err != nil {
return nil , err
}
2016-12-28 20:03:40 -05:00
return constraints , nil
}
2020-08-13 21:54:46 -04:00
func ( h * postgreSQL ) getConstraints ( q queryable ) ( [ ] pgConstraint , error ) {
var constraints [ ] pgConstraint
2021-04-23 02:08:53 +02:00
const sql = `
2020-08-13 21:54:46 -04:00
SELECT conrelid : : regclass AS table_from , conname , pg_get_constraintdef ( pg_constraint . oid )
FROM pg_constraint
INNER JOIN pg_namespace ON pg_namespace . oid = pg_constraint . connamespace
WHERE contype = 'f'
AND pg_namespace . nspname NOT IN ( ' pg_catalog ' , ' information_schema ' , ' crdb_internal ' )
AND pg_namespace . nspname NOT LIKE ' pg_toast % '
AND pg_namespace . nspname NOT LIKE ' \ _timescaledb % ' ;
`
rows , err := q . Query ( sql )
if err != nil {
return nil , err
}
defer rows . Close ( )
for rows . Next ( ) {
var constraint pgConstraint
if err = rows . Scan (
& constraint . tableName ,
& constraint . constraintName ,
& constraint . definition ,
) ; err != nil {
return nil , err
}
constraints = append ( constraints , constraint )
}
if err = rows . Err ( ) ; err != nil {
return nil , err
}
return constraints , nil
}
func ( h * postgreSQL ) dropAndRecreateConstraints ( db * sql . DB , loadFn loadFunction ) ( err error ) {
defer func ( ) {
// Re-create constraints again after load
2021-04-23 02:08:53 +02:00
var b strings . Builder
2020-08-13 21:54:46 -04:00
for _ , constraint := range h . constraints {
2021-04-23 02:08:53 +02:00
b . WriteString ( fmt . Sprintf (
2020-08-13 21:54:46 -04:00
"ALTER TABLE %s ADD CONSTRAINT %s %s;" ,
h . quoteKeyword ( constraint . tableName ) ,
h . quoteKeyword ( constraint . constraintName ) ,
constraint . definition ,
2021-04-23 02:08:53 +02:00
) )
2020-08-13 21:54:46 -04:00
}
2021-04-23 02:08:53 +02:00
if _ , err2 := db . Exec ( b . String ( ) ) ; err2 != nil && err == nil {
2020-08-13 21:54:46 -04:00
err = err2
}
} ( )
2021-04-23 02:08:53 +02:00
var b strings . Builder
2020-08-13 21:54:46 -04:00
for _ , constraint := range h . constraints {
2021-04-23 02:08:53 +02:00
b . WriteString ( fmt . Sprintf (
2020-08-13 21:54:46 -04:00
"ALTER TABLE %s DROP CONSTRAINT %s;" ,
h . quoteKeyword ( constraint . tableName ) ,
h . quoteKeyword ( constraint . constraintName ) ,
2021-04-23 02:08:53 +02:00
) )
2020-08-13 21:54:46 -04:00
}
2021-04-23 02:08:53 +02:00
if _ , err := db . Exec ( b . String ( ) ) ; err != nil {
2020-08-13 21:54:46 -04:00
return err
}
tx , err := db . Begin ( )
if err != nil {
return err
}
defer tx . Rollback ( )
if err = loadFn ( tx ) ; err != nil {
return err
}
return tx . Commit ( )
}
2020-06-17 15:07:58 -04:00
func ( h * postgreSQL ) disableTriggers ( db * sql . DB , loadFn loadFunction ) ( err error ) {
2016-12-28 20:03:40 -05:00
defer func ( ) {
2021-04-23 02:08:53 +02:00
var b strings . Builder
2016-12-28 20:03:40 -05:00
for _ , table := range h . tables {
2021-04-23 02:08:53 +02:00
b . WriteString ( fmt . Sprintf ( "ALTER TABLE %s ENABLE TRIGGER ALL;" , h . quoteKeyword ( table ) ) )
2016-12-28 20:03:40 -05:00
}
2021-04-23 02:08:53 +02:00
if _ , err2 := db . Exec ( b . String ( ) ) ; err2 != nil && err == nil {
2018-10-03 03:20:02 +08:00
err = err2
}
2016-12-28 20:03:40 -05:00
} ( )
tx , err := db . Begin ( )
if err != nil {
return err
}
2021-04-23 02:08:53 +02:00
var b strings . Builder
2016-12-28 20:03:40 -05:00
for _ , table := range h . tables {
2021-04-23 02:08:53 +02:00
b . WriteString ( fmt . Sprintf ( "ALTER TABLE %s DISABLE TRIGGER ALL;" , h . quoteKeyword ( table ) ) )
2016-12-28 20:03:40 -05:00
}
2021-04-23 02:08:53 +02:00
if _ , err = tx . Exec ( b . String ( ) ) ; err != nil {
2016-12-28 20:03:40 -05:00
return err
}
if err = loadFn ( tx ) ; err != nil {
tx . Rollback ( )
return err
}
return tx . Commit ( )
}
2020-06-17 15:07:58 -04:00
func ( h * postgreSQL ) makeConstraintsDeferrable ( db * sql . DB , loadFn loadFunction ) ( err error ) {
2016-12-28 20:03:40 -05:00
defer func ( ) {
// ensure constraint being not deferrable again after load
2021-04-23 02:08:53 +02:00
var b strings . Builder
2016-12-28 20:03:40 -05:00
for _ , constraint := range h . nonDeferrableConstraints {
2021-04-23 02:08:53 +02:00
b . WriteString ( fmt . Sprintf ( "ALTER TABLE %s ALTER CONSTRAINT %s NOT DEFERRABLE;" , h . quoteKeyword ( constraint . tableName ) , h . quoteKeyword ( constraint . constraintName ) ) )
2016-12-28 20:03:40 -05:00
}
2021-04-23 02:08:53 +02:00
if _ , err2 := db . Exec ( b . String ( ) ) ; err2 != nil && err == nil {
2018-10-03 03:20:02 +08:00
err = err2
}
2016-12-28 20:03:40 -05:00
} ( )
2021-04-23 02:08:53 +02:00
var b strings . Builder
2016-12-28 20:03:40 -05:00
for _ , constraint := range h . nonDeferrableConstraints {
2021-04-23 02:08:53 +02:00
b . WriteString ( fmt . Sprintf ( "ALTER TABLE %s ALTER CONSTRAINT %s DEFERRABLE;" , h . quoteKeyword ( constraint . tableName ) , h . quoteKeyword ( constraint . constraintName ) ) )
2016-12-28 20:03:40 -05:00
}
2021-04-23 02:08:53 +02:00
if _ , err := db . Exec ( b . String ( ) ) ; err != nil {
2016-12-28 20:03:40 -05:00
return err
}
tx , err := db . Begin ( )
if err != nil {
return err
}
2018-10-03 03:20:02 +08:00
defer tx . Rollback ( )
2016-12-28 20:03:40 -05:00
if _ , err = tx . Exec ( "SET CONSTRAINTS ALL DEFERRED" ) ; err != nil {
2018-10-03 03:20:02 +08:00
return err
2016-12-28 20:03:40 -05:00
}
if err = loadFn ( tx ) ; err != nil {
return err
}
return tx . Commit ( )
}
2020-06-17 15:07:58 -04:00
func ( h * postgreSQL ) disableReferentialIntegrity ( db * sql . DB , loadFn loadFunction ) ( err error ) {
2016-12-28 20:03:40 -05:00
// ensure sequences being reset after load
2020-06-17 15:07:58 -04:00
if ! h . skipResetSequences {
defer func ( ) {
if err2 := h . resetSequences ( db ) ; err2 != nil && err == nil {
err = err2
}
} ( )
}
2016-12-28 20:03:40 -05:00
2020-08-13 21:54:46 -04:00
if h . useDropConstraint {
return h . dropAndRecreateConstraints ( db , loadFn )
}
2020-06-17 15:07:58 -04:00
if h . useAlterConstraint {
2016-12-28 20:03:40 -05:00
return h . makeConstraintsDeferrable ( db , loadFn )
}
2018-10-03 03:20:02 +08:00
return h . disableTriggers ( db , loadFn )
2016-12-28 20:03:40 -05:00
}
2020-06-17 15:07:58 -04:00
func ( h * postgreSQL ) resetSequences ( db * sql . DB ) error {
resetSequencesTo := h . resetSequencesTo
if resetSequencesTo == 0 {
resetSequencesTo = 10000
}
2016-12-28 20:03:40 -05:00
for _ , sequence := range h . sequences {
_ , err := db . Exec ( fmt . Sprintf ( "SELECT SETVAL('%s', %d)" , sequence , resetSequencesTo ) )
if err != nil {
return err
}
}
return nil
}
2018-10-03 03:20:02 +08:00
2020-06-17 15:07:58 -04:00
func ( h * postgreSQL ) isTableModified ( q queryable , tableName string ) ( bool , error ) {
2018-10-03 03:20:02 +08:00
checksum , err := h . getChecksum ( q , tableName )
if err != nil {
return false , err
}
oldChecksum := h . tablesChecksum [ tableName ]
return oldChecksum == "" || checksum != oldChecksum , nil
}
2020-06-17 15:07:58 -04:00
func ( h * postgreSQL ) afterLoad ( q queryable ) error {
2018-10-03 03:20:02 +08:00
if h . tablesChecksum != nil {
return nil
}
h . tablesChecksum = make ( map [ string ] string , len ( h . tables ) )
for _ , t := range h . tables {
checksum , err := h . getChecksum ( q , t )
if err != nil {
return err
}
h . tablesChecksum [ t ] = checksum
}
return nil
}
2020-06-17 15:07:58 -04:00
func ( h * postgreSQL ) getChecksum ( q queryable , tableName string ) ( string , error ) {
2018-10-03 03:20:02 +08:00
sqlStr := fmt . Sprintf ( `
2020-08-13 21:54:46 -04:00
SELECT md5 ( CAST ( ( json_agg ( t . * ) ) AS TEXT ) )
2018-10-03 03:20:02 +08:00
FROM % s AS t
` ,
h . quoteKeyword ( tableName ) ,
)
var checksum sql . NullString
if err := q . QueryRow ( sqlStr ) . Scan ( & checksum ) ; err != nil {
return "" , err
}
return checksum . String , nil
}
2020-06-17 15:07:58 -04:00
func ( * postgreSQL ) quoteKeyword ( s string ) string {
2018-10-03 03:20:02 +08:00
parts := strings . Split ( s , "." )
for i , p := range parts {
parts [ i ] = fmt . Sprintf ( ` "%s" ` , p )
}
return strings . Join ( parts , "." )
}