Remove QuoteStr() usage by BetaCat0 · Pull Request #1360 · go-xorm/xorm · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content
This repository was archived by the owner on Sep 7, 2021. It is now read-only.
This repository is currently being migrated. It's locked while the migration is in progress.

Remove QuoteStr() usage #1360

Merged
merged 1 commit into from
Jul 24, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 24 additions & 16 deletions engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ func (engine *Engine) SupportInsertMany() bool {

// QuoteStr returns the quote string reported by the engine's dialect.
// mysql and sqlite use ` and postgres uses ".
//
// Deprecated: use Quote instead.
func (engine *Engine) QuoteStr() string {
return engine.dialect.QuoteStr()
}
Expand All @@ -196,13 +197,10 @@ func (engine *Engine) Quote(value string) string {
return value
}

if string(value[0]) == engine.dialect.QuoteStr() || value[0] == '`' {
return value
}

value = strings.Replace(value, ".", engine.dialect.QuoteStr()+"."+engine.dialect.QuoteStr(), -1)
buf := builder.StringBuilder{}
engine.QuoteTo(&buf, value)

return engine.dialect.QuoteStr() + value + engine.dialect.QuoteStr()
return buf.String()
}

// QuoteTo quotes string and writes into the buffer
Expand All @@ -216,20 +214,30 @@ func (engine *Engine) QuoteTo(buf *builder.StringBuilder, value string) {
return
}

if string(value[0]) == engine.dialect.QuoteStr() || value[0] == '`' {
buf.WriteString(value)
quotePair := engine.dialect.Quote("")

if value[0] == '`' || len(quotePair) < 2 || value[0] == quotePair[0] { // no quote
_, _ = buf.WriteString(value)
return
} else {
prefix, suffix := quotePair[0], quotePair[1]

_ = buf.WriteByte(prefix)
for i := 0; i < len(value); i++ {
if value[i] == '.' {
_ = buf.WriteByte(suffix)
_ = buf.WriteByte('.')
_ = buf.WriteByte(prefix)
} else {
_ = buf.WriteByte(value[i])
}
}
_ = buf.WriteByte(suffix)
}

value = strings.Replace(value, ".", engine.dialect.QuoteStr()+"."+engine.dialect.QuoteStr(), -1)

buf.WriteString(engine.dialect.QuoteStr())
buf.WriteString(value)
buf.WriteString(engine.dialect.QuoteStr())
}

func (engine *Engine) quote(sql string) string {
return engine.dialect.QuoteStr() + sql + engine.dialect.QuoteStr()
return engine.dialect.Quote(sql)
}

// SqlType will be deprecated, please use SQLType instead
Expand Down Expand Up @@ -1581,7 +1589,7 @@ func (engine *Engine) formatColTime(col *core.Column, t time.Time) (v interface{
func (engine *Engine) formatTime(sqlTypeName string, t time.Time) (v interface{}) {
switch sqlTypeName {
case core.Time:
s := t.Format("2006-01-02 15:04:05") //time.RFC3339
s := t.Format("2006-01-02 15:04:05") // time.RFC3339
v = s[11:19]
case core.Date:
v = t.Format("2006-01-02")
Expand Down
23 changes: 22 additions & 1 deletion helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ func rValue(bean interface{}) reflect.Value {

// rType reports the reflect.Type of bean, looking through one level of
// pointer when bean is a pointer.
func rType(bean interface{}) reflect.Type {
	v := reflect.ValueOf(bean)
	return reflect.Indirect(v).Type()
}

Expand Down Expand Up @@ -309,3 +309,24 @@ func sliceEq(left, right []string) bool {
// indexName builds the name of an index on tableName from idxName, in the
// form IDX_<tableName>_<idxName>.
func indexName(tableName, idxName string) string {
	const layout = "IDX_%v_%v"
	return fmt.Sprintf(layout, tableName, idxName)
}

// eraseAny returns value with every occurrence of each string in strToErase
// removed. When nothing is given to erase, value is returned as is.
func eraseAny(value string, strToErase ...string) string {
	n := len(strToErase)
	if n == 0 {
		return value
	}

	// strings.NewReplacer takes old/new pairs; each "new" slot is left as the
	// empty string so that matches are simply dropped.
	pairs := make([]string, 2*n)
	for i, old := range strToErase {
		pairs[2*i] = old
	}

	return strings.NewReplacer(pairs...).Replace(value)
}

// quoteColumns applies quoteFunc to every name in cols and joins the results
// with sep followed by a single space.
//
// The cols slice is left untouched: the quoted names are written to a fresh
// slice, so a caller may keep using cols (or pass it again) without ending up
// with names that were quoted twice.
func quoteColumns(cols []string, quoteFunc func(string) string, sep string) string {
	quoted := make([]string, len(cols))
	for i, col := range cols {
		quoted[i] = quoteFunc(col)
	}
	return strings.Join(quoted, sep+" ")
}
22 changes: 21 additions & 1 deletion helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@

package xorm

import "testing"
import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestSplitTag(t *testing.T) {
var cases = []struct {
Expand All @@ -24,3 +28,19 @@ func TestSplitTag(t *testing.T) {
}
}
}

// TestEraseAny checks that eraseAny strips exactly the requested quote
// characters from a statement and leaves it alone when none are given.
func TestEraseAny(t *testing.T) {
	const raw = "SELECT * FROM `table`.[table_name]"

	cases := []struct {
		erase    []string
		expected string
	}{
		{nil, raw},
		{[]string{"`"}, "SELECT * FROM table.[table_name]"},
		{[]string{"`", "[", "]"}, "SELECT * FROM table.table_name"},
	}

	for _, c := range cases {
		assert.EqualValues(t, c.expected, eraseAny(raw, c.erase...))
	}
}

func TestQuoteColumns(t *testing.T) {
cols := []string{"f1", "f2", "f3"}
quoteFunc := func(value string) string {
return "[" + value + "]"
}

assert.EqualValues(t, "[f1], [f2], [f3]", quoteColumns(cols, quoteFunc, ","))
}
24 changes: 8 additions & 16 deletions session_insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -242,23 +242,17 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error

var sql string
if session.engine.dialect.DBType() == core.ORACLE {
temp := fmt.Sprintf(") INTO %s (%v%v%v) VALUES (",
temp := fmt.Sprintf(") INTO %s (%v) VALUES (",
session.engine.Quote(tableName),
session.engine.QuoteStr(),
strings.Join(colNames, session.engine.QuoteStr()+", "+session.engine.QuoteStr()),
session.engine.QuoteStr())
sql = fmt.Sprintf("INSERT ALL INTO %s (%v%v%v) VALUES (%v) SELECT 1 FROM DUAL",
quoteColumns(colNames, session.engine.Quote, ","))
sql = fmt.Sprintf("INSERT ALL INTO %s (%v) VALUES (%v) SELECT 1 FROM DUAL",
session.engine.Quote(tableName),
session.engine.QuoteStr(),
strings.Join(colNames, session.engine.QuoteStr()+", "+session.engine.QuoteStr()),
session.engine.QuoteStr(),
quoteColumns(colNames, session.engine.Quote, ","),
strings.Join(colMultiPlaces, temp))
} else {
sql = fmt.Sprintf("INSERT INTO %s (%v%v%v) VALUES (%v)",
sql = fmt.Sprintf("INSERT INTO %s (%v) VALUES (%v)",
session.engine.Quote(tableName),
session.engine.QuoteStr(),
strings.Join(colNames, session.engine.QuoteStr()+", "+session.engine.QuoteStr()),
session.engine.QuoteStr(),
quoteColumns(colNames, session.engine.Quote, ","),
strings.Join(colMultiPlaces, "),("))
}
res, err := session.exec(sql, args...)
Expand Down Expand Up @@ -378,11 +372,9 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
output = fmt.Sprintf(" OUTPUT Inserted.%s", table.AutoIncrement)
}
if len(colPlaces) > 0 {
sqlStr = fmt.Sprintf("INSERT INTO %s (%v%v%v)%s VALUES (%v)",
sqlStr = fmt.Sprintf("INSERT INTO %s (%v)%s VALUES (%v)",
session.engine.Quote(tableName),
session.engine.QuoteStr(),
strings.Join(colNames, session.engine.Quote(", ")),
session.engine.QuoteStr(),
quoteColumns(colNames, session.engine.Quote, ","),
output,
colPlaces)
} else {
Expand Down
17 changes: 9 additions & 8 deletions session_update.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,14 +96,15 @@ func (session *Session) cacheUpdate(table *core.Table, tableName, sqlStr string,
return ErrCacheFailed
}
kvs := strings.Split(strings.TrimSpace(sqls[1]), ",")

for idx, kv := range kvs {
sps := strings.SplitN(kv, "=", 2)
sps2 := strings.Split(sps[0], ".")
colName := sps2[len(sps2)-1]
if strings.Contains(colName, "`") {
colName = strings.TrimSpace(strings.Replace(colName, "`", "", -1))
} else if strings.Contains(colName, session.engine.QuoteStr()) {
colName = strings.TrimSpace(strings.Replace(colName, session.engine.QuoteStr(), "", -1))
// treat quote prefix, suffix and '`' as quotes
quotes := append(strings.Split(session.engine.Quote(""), ""), "`")
if strings.ContainsAny(colName, strings.Join(quotes, "")) {
colName = strings.TrimSpace(eraseAny(colName, quotes...))
} else {
session.engine.logger.Debug("[cacheUpdate] cannot find column", tableName, colName)
return ErrCacheFailed
Expand Down Expand Up @@ -221,19 +222,19 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
}
}

//for update action to like "column = column + ?"
// for update action to like "column = column + ?"
incColumns := session.statement.getInc()
for _, v := range incColumns {
colNames = append(colNames, session.engine.Quote(v.colName)+" = "+session.engine.Quote(v.colName)+" + ?")
args = append(args, v.arg)
}
//for update action to like "column = column - ?"
// for update action to like "column = column - ?"
decColumns := session.statement.getDec()
for _, v := range decColumns {
colNames = append(colNames, session.engine.Quote(v.colName)+" = "+session.engine.Quote(v.colName)+" - ?")
args = append(args, v.arg)
}
//for update action to like "column = expression"
// for update action to like "column = expression"
exprColumns := session.statement.getExpr()
for _, v := range exprColumns {
colNames = append(colNames, session.engine.Quote(v.colName)+" = "+v.expr)
Expand Down Expand Up @@ -382,7 +383,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
}

if cacher := session.engine.getCacher(tableName); cacher != nil && session.statement.UseCache {
//session.cacheUpdate(table, tableName, sqlStr, args...)
// session.cacheUpdate(table, tableName, sqlStr, args...)
session.engine.logger.Debug("[cacheUpdate] clear table ", tableName)
cacher.ClearIds(tableName)
cacher.ClearBeans(tableName)
Expand Down
2 changes: 1 addition & 1 deletion session_update_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ import (
"testing"
"time"

"xorm.io/core"
"github.com/stretchr/testify/assert"
"xorm.io/core"
)

func TestUpdateMap(t *testing.T) {
Expand Down
29 changes: 10 additions & 19 deletions statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ package xorm

import (
"database/sql/driver"
"errors"
"fmt"
"reflect"
"strings"
Expand Down Expand Up @@ -398,7 +397,7 @@ func (statement *Statement) buildUpdates(bean interface{},
continue
}
} else {
//TODO: how to handler?
// TODO: how to handler?
panic("not supported")
}
} else {
Expand Down Expand Up @@ -579,21 +578,9 @@ func (statement *Statement) getExpr() map[string]exprParam {

func (statement *Statement) col2NewColsWithQuote(columns ...string) []string {
newColumns := make([]string, 0)
quotes := append(strings.Split(statement.Engine.Quote(""), ""), "`")
for _, col := range columns {
col = strings.Replace(col, "`", "", -1)
col = strings.Replace(col, statement.Engine.QuoteStr(), "", -1)
ccols := strings.Split(col, ",")
for _, c := range ccols {
fields := strings.Split(strings.TrimSpace(c), ".")
if len(fields) == 1 {
newColumns = append(newColumns, statement.Engine.quote(fields[0]))
} else if len(fields) == 2 {
newColumns = append(newColumns, statement.Engine.quote(fields[0])+"."+
statement.Engine.quote(fields[1]))
} else {
panic(errors.New("unwanted colnames"))
}
}
newColumns = append(newColumns, statement.Engine.Quote(eraseAny(col, quotes...)))
}
return newColumns
}
Expand Down Expand Up @@ -764,7 +751,9 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition
return statement
}
tbs := strings.Split(tp.TableName(), ".")
var aliasName = strings.Trim(tbs[len(tbs)-1], statement.Engine.QuoteStr())
quotes := append(strings.Split(statement.Engine.Quote(""), ""), "`")

var aliasName = strings.Trim(tbs[len(tbs)-1], strings.Join(quotes, ""))
fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition)
statement.joinArgs = append(statement.joinArgs, subQueryArgs...)
case *builder.Builder:
Expand All @@ -774,7 +763,9 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition
return statement
}
tbs := strings.Split(tp.TableName(), ".")
var aliasName = strings.Trim(tbs[len(tbs)-1], statement.Engine.QuoteStr())
quotes := append(strings.Split(statement.Engine.Quote(""), ""), "`")

var aliasName = strings.Trim(tbs[len(tbs)-1], strings.Join(quotes, ""))
fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition)
statement.joinArgs = append(statement.joinArgs, subQueryArgs...)
default:
Expand Down Expand Up @@ -1246,7 +1237,7 @@ func (statement *Statement) convertUpdateSQL(sqlStr string) (string, string) {

var whereStr = sqls[1]

//TODO: for postgres only, if any other database?
// TODO: for postgres only, if any other database?
var paraStr string
if statement.Engine.dialect.DBType() == core.POSTGRES {
paraStr = "$"
Expand Down
11 changes: 10 additions & 1 deletion statement_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ import (
"strings"
"testing"

"xorm.io/core"
"github.com/stretchr/testify/assert"
"xorm.io/core"
)

var colStrTests = []struct {
Expand Down Expand Up @@ -237,3 +237,12 @@ func TestUpdateIgnoreOnlyFromDBFields(t *testing.T) {
testEngine.Update(record)
assertGetRecord()
}

// TestCol2NewColsWithQuote checks that col2NewColsWithQuote yields, for each
// plain or table-qualified column, the same text as Engine.Quote.
func TestCol2NewColsWithQuote(t *testing.T) {
	statement := createTestStatement()
	cols := []string{"f1", "f2", "t3.f3"}

	actual := statement.col2NewColsWithQuote(cols...)

	expected := make([]string, 0, len(cols))
	for _, col := range cols {
		expected = append(expected, statement.Engine.Quote(col))
	}

	assert.EqualValues(t, expected, actual)
}
0