From 81d9c54d27827805eae58c5b366793cd73685382 Mon Sep 17 00:00:00 2001 From: Ariel Mashraki Date: Thu, 23 Sep 2021 18:22:53 +0300 Subject: [PATCH] entc/gen: allow spaces in enum fields --- .golangci.yml | 4 ++-- dialect/sql/schema/mysql.go | 12 +++++++++--- dialect/sql/schema/mysql_test.go | 8 +++++--- entc/gen/func.go | 2 +- entc/gen/type.go | 8 ++++---- entc/gen/type_test.go | 1 + entc/integration/docker-compose.yaml | 2 +- entc/integration/ent/migrate/schema.go | 2 +- entc/integration/ent/schema/user.go | 2 +- entc/integration/ent/user/user.go | 3 ++- entc/integration/gremlin/ent/user/user.go | 3 ++- 11 files changed, 29 insertions(+), 18 deletions(-) diff --git a/.golangci.yml b/.golangci.yml index 2d7956e43d..05b56a388c 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -7,8 +7,8 @@ linters-settings: dupl: threshold: 100 funlen: - lines: 110 - statements: 100 + lines: 115 + statements: 115 goheader: template: |- Copyright 2019-present Facebook Inc. All rights reserved. diff --git a/dialect/sql/schema/mysql.go b/dialect/sql/schema/mysql.go index a93b369dbd..0641c7a68b 100644 --- a/dialect/sql/schema/mysql.go +++ b/dialect/sql/schema/mysql.go @@ -471,9 +471,15 @@ func (d *MySQL) scanColumn(c *Column, rows *sql.Rows) error { c.Type = field.TypeJSON case "enum": c.Type = field.TypeEnum - c.Enums = make([]string, len(parts)-1) - for i, e := range parts[1:] { - c.Enums[i] = strings.Trim(e, "'") + // Parse the enum values according to the MySQL format.
+ // github.com/mysql/mysql-server/blob/8.0/sql/field.cc#Field_enum::sql_type + values := strings.TrimSuffix(strings.TrimPrefix(c.typ, "enum("), ")") + if values == "" { + return fmt.Errorf("mysql: unexpected enum type: %q", c.typ) + } + parts := strings.Split(values, "','") + for i := range parts { + c.Enums = append(c.Enums, strings.Trim(parts[i], "'")) } case "char": c.Type = field.TypeOther diff --git a/dialect/sql/schema/mysql_test.go b/dialect/sql/schema/mysql_test.go index 7a0650c824..6746a421ef 100644 --- a/dialect/sql/schema/mysql_test.go +++ b/dialect/sql/schema/mysql_test.go @@ -307,8 +307,9 @@ func TestMySQL_Create(t *testing.T) { Columns: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, {Name: "name", Type: field.TypeString, Nullable: true}, - {Name: "enums1", Type: field.TypeEnum, Enums: []string{"a", "b"}}, // add enum. - {Name: "enums2", Type: field.TypeEnum, Enums: []string{"a"}}, // remove enum. + {Name: "enums1", Type: field.TypeEnum, Enums: []string{"a", "b"}}, // add enum. + {Name: "enums2", Type: field.TypeEnum, Enums: []string{"a"}}, // remove enum. + {Name: "enums3", Type: field.TypeEnum, Enums: []string{"a", "b c"}}, // no changes. }, PrimaryKey: []*Column{ {Name: "id", Type: field.TypeInt, Increment: true}, @@ -324,7 +325,8 @@ func TestMySQL_Create(t *testing.T) { AddRow("id", "bigint(20)", "NO", "PRI", "NULL", "auto_increment", "", "", nil, nil). AddRow("name", "varchar(255)", "YES", "YES", "NULL", "", "", "", nil, nil). AddRow("enums1", "enum('a')", "YES", "NO", "NULL", "", "", "", nil, nil). - AddRow("enums2", "enum('b', 'a')", "NO", "YES", "NULL", "", "", "", nil, nil)) + AddRow("enums2", "enum('b','a')", "NO", "YES", "NULL", "", "", "", nil, nil).
+ AddRow("enums3", "enum('a','b c')", "NO", "YES", "NULL", "", "", "", nil, nil)) mock.ExpectQuery(escape("SELECT `index_name`, `column_name`, `sub_part`, `non_unique`, `seq_in_index` FROM `INFORMATION_SCHEMA`.`STATISTICS` WHERE `TABLE_SCHEMA` = (SELECT DATABASE()) AND `TABLE_NAME` = ? ORDER BY `index_name`, `seq_in_index`")). WithArgs("users"). WillReturnRows(sqlmock.NewRows([]string{"index_name", "column_name", "sub_part", "non_unique", "seq_in_index"}). diff --git a/entc/gen/func.go b/entc/gen/func.go index fd7e2f3d27..dccbeec127 100644 --- a/entc/gen/func.go +++ b/entc/gen/func.go @@ -114,7 +114,7 @@ func plural(name string) string { } func isSeparator(r rune) bool { - return r == '_' || r == '-' + return r == '_' || r == '-' || unicode.IsSpace(r) } func pascalWords(words []string) string { diff --git a/entc/gen/type.go b/entc/gen/type.go index d18154a3d3..1d1c600225 100644 --- a/entc/gen/type.go +++ b/entc/gen/type.go @@ -1371,16 +1371,16 @@ func (f Field) enums(lf *load.Field) ([]Enum, error) { enums := make([]Enum, 0, len(lf.Enums)) values := make(map[string]bool, len(lf.Enums)) for i := range lf.Enums { - switch name, value := lf.Enums[i].N, lf.Enums[i].V; { + switch name, value := f.EnumName(lf.Enums[i].N), lf.Enums[i].V; { case value == "": return nil, fmt.Errorf("%q field value cannot be empty", f.Name) case values[value]: return nil, fmt.Errorf("duplicate values %q for enum field %q", value, f.Name) - case strings.IndexFunc(value, unicode.IsSpace) != -1: - return nil, fmt.Errorf("enum value %q cannot contain spaces", value) + case !token.IsIdentifier(name): + return nil, fmt.Errorf("enum %q does not have a valid Go identifier (%q)", value, name) default: values[value] = true - enums = append(enums, Enum{Name: f.EnumName(name), Value: value}) + enums = append(enums, Enum{Name: name, Value: value}) } } if value := lf.DefaultValue; value != nil { diff --git a/entc/gen/type_test.go b/entc/gen/type_test.go index 53e916d9ce..d70c31314c 100644 ---
a/entc/gen/type_test.go +++ b/entc/gen/type_test.go @@ -147,6 +147,7 @@ func TestField_EnumName(t *testing.T) { {"MP4", "TypeMP4"}, {"unknown", "TypeUnknown"}, {"user_data", "TypeUserData"}, + {"test user", "TypeTestUser"}, } for _, tt := range tests { require.Equal(t, tt.enum, Field{Name: "Type"}.EnumName(tt.name)) diff --git a/entc/integration/docker-compose.yaml b/entc/integration/docker-compose.yaml index d3bff21e7d..2db7edb049 100644 --- a/entc/integration/docker-compose.yaml +++ b/entc/integration/docker-compose.yaml @@ -119,7 +119,7 @@ services: gremlin: platform: linux/amd64 image: entgo/gremlin-server - build: gremlin-server + build: compose/gremlin-server restart: on-failure ports: - 8182:8182 diff --git a/entc/integration/ent/migrate/schema.go b/entc/integration/ent/migrate/schema.go index 841ef27517..3e2259e19b 100644 --- a/entc/integration/ent/migrate/schema.go +++ b/entc/integration/ent/migrate/schema.go @@ -371,7 +371,7 @@ var ( {Name: "address", Type: field.TypeString, Nullable: true}, {Name: "phone", Type: field.TypeString, Unique: true, Nullable: true}, {Name: "password", Type: field.TypeString, Nullable: true}, - {Name: "role", Type: field.TypeEnum, Enums: []string{"user", "admin", "free-user"}, Default: "user"}, + {Name: "role", Type: field.TypeEnum, Enums: []string{"user", "admin", "free-user", "test user"}, Default: "user"}, {Name: "sso_cert", Type: field.TypeString, Nullable: true}, {Name: "group_blocked", Type: field.TypeInt, Nullable: true}, {Name: "user_spouse", Type: field.TypeInt, Unique: true, Nullable: true}, diff --git a/entc/integration/ent/schema/user.go b/entc/integration/ent/schema/user.go index 8a549d37ae..e39f0624a9 100644 --- a/entc/integration/ent/schema/user.go +++ b/entc/integration/ent/schema/user.go @@ -44,7 +44,7 @@ func (User) Fields() []ent.Field { Optional(). Sensitive(), field.Enum("role"). - Values("user", "admin", "free-user"). + Values("user", "admin", "free-user", "test user").
Default("user"), field.String("SSOCert"). Optional(), diff --git a/entc/integration/ent/user/user.go b/entc/integration/ent/user/user.go index 5025a315c2..38bb157ba9 100644 --- a/entc/integration/ent/user/user.go +++ b/entc/integration/ent/user/user.go @@ -185,6 +185,7 @@ const ( RoleUser Role = "user" RoleAdmin Role = "admin" RoleFreeUser Role = "free-user" + RoleTestUser Role = "test user" ) func (r Role) String() string { @@ -194,7 +195,7 @@ func (r Role) String() string { // RoleValidator is a validator for the "role" field enum values. It is called by the builders before save. func RoleValidator(r Role) error { switch r { - case RoleUser, RoleAdmin, RoleFreeUser: + case RoleUser, RoleAdmin, RoleFreeUser, RoleTestUser: return nil default: return fmt.Errorf("user: invalid enum value for role field: %q", r)