[WIP] sdk/java: implementation of moduleTypeDefs by eunomie · Pull Request #10625 · dagger/dagger · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

[WIP] sdk/java: implementation of moduleTypeDefs #10625

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions cmd/codegen/generator/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ type Config struct {
// name is the expected value.
IsInit bool

// TypeDefsOnly indicates whether only type definitions should be generated, excluding other related code artifacts.
// This is used to generate the module's own types even if the module doesn't compile.
TypeDefsOnly bool

// ClientOnly indicates that the codegen should only generate the client code.
ClientOnly bool

Expand Down
159 changes: 114 additions & 45 deletions cmd/codegen/generator/go/templates/modules.go
Original file line number Diff line number Diff line change
Expand Up @@ -284,73 +284,87 @@ func (funcs goTemplateFuncs) moduleMainSrc() (string, error) { //nolint: gocyclo
tps, nextTps = nextTps, nil
}

return strings.Join([]string{
fmt.Sprintf("%#v", implementationCode),
mainSrc(funcs.CheckVersionCompatibility),
invokeSrc(objFunctionCases, createMod),
}, "\n"), nil
var out []string
if !funcs.cfg.TypeDefsOnly {
out = append(out, fmt.Sprintf("%#v", implementationCode))
}
out = append(out,
mainSrc(funcs.CheckVersionCompatibility, funcs.cfg.TypeDefsOnly),
registerSrc(createMod),
)
if !funcs.cfg.TypeDefsOnly {
out = append(out, invokeSrc(objFunctionCases))
}
return strings.Join(out, "\n"), nil
}

// dotLine extends statement a with a "." operator, a line break, and the
// identifier id, so a chained selector continues on the following line.
// The extended statement is returned.
func dotLine(a *Statement, id string) *Statement {
	withDot := a.Op(".")
	return withDot.Line().Id(id)
}

const (
parentJSONVar = "parentJSON"
parentNameVar = "parentName"
fnNameVar = "fnName"
inputArgsVar = "inputArgs"
invokeFuncName = "invoke"
parentJSONVar = "parentJSON"
parentNameVar = "parentName"
fnNameVar = "fnName"
inputArgsVar = "inputArgs"
invokeFuncName = "invoke"
registerFuncName = "register"
)

// mainSrc returns the static part of the generated code. It calls out to the
// "invoke" func, which is the mostly dynamically generated code that actually
// calls the user's functions.
func mainSrc(checkVersionCompatibility func(string) bool) string {
func mainSrc(checkVersionCompatibility func(string) bool, typeDefsOnly bool) string {
// Ensure compatibility with modules that predate Void return value handling
voidRet := `err`
if !checkVersionCompatibility("v0.12.0") {
voidRet = `_, err`
}

return `func main() {
ctx := context.Background()

// Direct slog to the new stderr. This is only for dev time debugging, and
// runtime errors/warnings.
slog.SetDefault(slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
Level: slog.LevelWarn,
})))
var dispatch string
if typeDefsOnly {
dispatch = `func dispatch(ctx context.Context) (rerr error) {
ctx = telemetry.InitEmbedded(ctx, resource.NewWithAttributes(
semconv.SchemaURL,
semconv.ServiceNameKey.String("dagger-go-sdk"),
// TODO version?
))
defer telemetry.Close()

if err := dispatch(ctx); err != nil {
os.Exit(2)
}
}
// A lot of the "work" actually happens when we're marshalling the return
// value, which entails getting object IDs, which happens in MarshalJSON,
// which has no ctx argument, so we use this lovely global variable.
setMarshalContext(ctx)

func convertError(rerr error) *dagger.Error {
var gqlErr *gqlerror.Error
if errors.As(rerr, &gqlErr) {
dagErr := dag.Error(gqlErr.Message)
if gqlErr.Extensions != nil {
keys := make([]string, 0, len(gqlErr.Extensions))
for k := range gqlErr.Extensions {
keys = append(keys, k)
}
sort.Strings(keys)
for _, k := range keys {
val, err := json.Marshal(gqlErr.Extensions[k])
if err != nil {
fmt.Println("failed to marshal error value:", err)
}
dagErr = dagErr.WithValue(k, dagger.JSON(val))
fnCall := dag.CurrentFunctionCall()
defer func() {
if rerr != nil {
if ` + voidRet + ` := fnCall.ReturnError(ctx, convertError(rerr)); err != nil {
fmt.Println("failed to return error:", err, "\noriginal error:", rerr)
}
}
return dagErr
}()

result, err := register()
if err != nil {
var exec *dagger.ExecError
if errors.As(err, &exec) {
return exec.Unwrap()
}
return err
}
resultBytes, err := json.Marshal(result)
if err != nil {
return fmt.Errorf("marshal: %w", err)
}
return dag.Error(rerr.Error())
}

func dispatch(ctx context.Context) (rerr error) {
if ` + voidRet + ` := fnCall.ReturnValue(ctx, dagger.JSON(resultBytes)); err != nil {
return fmt.Errorf("store return value: %w", err)
}
return nil
}`
} else {
dispatch = `func dispatch(ctx context.Context) (rerr error) {
ctx = telemetry.InitEmbedded(ctx, resource.NewWithAttributes(
semconv.SchemaURL,
semconv.ServiceNameKey.String("dagger-go-sdk"),
Expand Down Expand Up @@ -420,10 +434,50 @@ func dispatch(ctx context.Context) (rerr error) {
}
return nil
}`
}

return `func main() {
ctx := context.Background()

// Direct slog to the new stderr. This is only for dev time debugging, and
// runtime errors/warnings.
slog.SetDefault(slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
Level: slog.LevelWarn,
})))

if err := dispatch(ctx); err != nil {
os.Exit(2)
}
}

func convertError(rerr error) *dagger.Error {
var gqlErr *gqlerror.Error
if errors.As(rerr, &gqlErr) {
dagErr := dag.Error(gqlErr.Message)
if gqlErr.Extensions != nil {
keys := make([]string, 0, len(gqlErr.Extensions))
for k := range gqlErr.Extensions {
keys = append(keys, k)
}
sort.Strings(keys)
for _, k := range keys {
val, err := json.Marshal(gqlErr.Extensions[k])
if err != nil {
fmt.Println("failed to marshal error value:", err)
}
dagErr = dagErr.WithValue(k, dagger.JSON(val))
}
}
return dagErr
}
return dag.Error(rerr.Error())
}

` + dispatch
}

// the source code of the invoke func, which is the mostly dynamically generated code that actually calls the user's functions
func invokeSrc(objFunctionCases map[string][]Code, createMod Code) string {
func invokeSrc(objFunctionCases map[string][]Code) string {
// each `case` statement for every object name, which makes up the body of the invoke func
objNames := []string{}
for objName := range objFunctionCases {
Expand All @@ -437,7 +491,7 @@ func invokeSrc(objFunctionCases map[string][]Code, createMod Code) string {
}
// when the object name is empty, return the module definition
objCases = append(objCases, Case(Lit("")).Block(
Return(createMod, Nil()),
Return(Id(registerFuncName).Call()),
))
// default case (return error)
objCases = append(objCases, Default().Block(
Expand Down Expand Up @@ -471,6 +525,21 @@ func invokeSrc(objFunctionCases map[string][]Code, createMod Code) string {
return fmt.Sprintf("%#v", invokeFunc)
}

// registerSrc returns the source code of the generated "register" func, which
// exposes the module's defined types. The generated func returns createMod
// (the code that builds the module definition) together with a nil error.
//
// The emitted code has the shape:
//
//	func register() (_ any, err error) {
//		return <createMod>, nil
//	}
func registerSrc(createMod Code) string {
	// func register(
	registerFunc := Func().Id(registerFuncName).Params().Params(
		// ) (_ any,
		Id("_").Id("any"),
		// err error)
		Id("err").Error(),
	).Block(
		Return(createMod, Nil()),
	)

	return fmt.Sprintf("%#v", registerFunc)
}

// TODO: use jennifer for generating this magical typedef
func (ps *parseState) renderNameOrStruct(t types.Type) string {
if alias, ok := t.(*types.Alias); ok {
Expand Down
17 changes: 10 additions & 7 deletions cmd/codegen/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ var (

clientOnly bool

isInit bool
isInit bool
typeDefsOnly bool

bundle bool

Expand Down Expand Up @@ -63,6 +64,7 @@ func init() {
rootCmd.Flags().StringVar(&moduleName, "module-name", "", "name of module to generate code for")
rootCmd.Flags().BoolVar(&merge, "merge", false, "merge module deps with project's existing go.mod in a parent directory")
rootCmd.Flags().BoolVar(&isInit, "is-init", false, "whether this command is initializing a new module")
rootCmd.Flags().BoolVar(&typeDefsOnly, "typedefs-only", false, "generate only type definitions (no client code)")
rootCmd.Flags().BoolVar(&clientOnly, "client-only", false, "generate only client code")
rootCmd.Flags().BoolVar(&bundle, "bundle", false, "generate the client in bundle mode")
rootCmd.Flags().StringVar(&moduleSourceID, "module-source-id", "", "id of the module source to generate code for")
Expand All @@ -76,12 +78,13 @@ func ClientGen(cmd *cobra.Command, args []string) error {
ctx = telemetry.InitEmbedded(ctx, nil)

cfg := generator.Config{
Lang: generator.SDKLang(lang),
OutputDir: outputDir,
Merge: merge,
IsInit: isInit,
ClientOnly: clientOnly,
Bundle: bundle,
Lang: generator.SDKLang(lang),
OutputDir: outputDir,
Merge: merge,
IsInit: isInit,
TypeDefsOnly: typeDefsOnly,
ClientOnly: clientOnly,
Bundle: bundle,
}

// If a module source ID is provided or no introspection JSON is provided, we will query
Expand Down
22 changes: 18 additions & 4 deletions core/schema/modulesource.go
Original file line number Diff line number Diff line change
Expand Up @@ -2010,9 +2010,9 @@ func (s *moduleSourceSchema) runModuleDefInSDK(ctx context.Context, src, srcInst
return nil, ErrSDKRuntimeNotImplemented{SDK: src.Self.SDK.Source}
}

// get the runtime container, which is what is exec'd when calling functions in the module
var err error
mod.Runtime, err = runtimeImpl.Runtime(ctx, mod.Deps, srcInstContentHashed)
// get the typedefs container dedicated to retrieving the module's definition.
// this will fall back to the runtime container if `moduleTypeDefs` is not defined.
typeDefs, err := runtimeImpl.TypeDefs(ctx, mod.Deps, srcInstContentHashed)
if err != nil {
return nil, fmt.Errorf("failed to get module runtime: %w", err)
}
Expand Down Expand Up @@ -2048,7 +2048,7 @@ func (s *moduleSourceSchema) runModuleDefInSDK(ctx context.Context, src, srcInst
ctx,
mod,
nil,
mod.Runtime,
typeDefs,
core.NewFunction("", &core.TypeDef{
Kind: core.TypeDefKindObject,
AsObject: dagql.NonNull(core.NewObjectTypeDef("Module", "")),
Expand Down Expand Up @@ -2162,10 +2162,24 @@ func (s *moduleSourceSchema) moduleSourceAsModule(
modName := src.Self.ModuleName

if src.Self.SDKImpl != nil {
runtimeImpl, ok := src.Self.SDKImpl.AsRuntime()
if !ok {
return inst, ErrSDKRuntimeNotImplemented{SDK: src.Self.SDK.Source}
}

mod, err = s.runModuleDefInSDK(ctx, src, srcInstContentHashed, mod)
if err != nil {
return inst, err
}

// pre-load the module Runtime
if mod.Runtime == nil {
mod.Runtime, err = runtimeImpl.Runtime(ctx, mod.Deps, srcInstContentHashed)
if err != nil {
return inst, err
}
}

mod.InstanceID = dagql.CurrentID(ctx)
} else {
// For no SDK, provide an empty stub module definition
Expand Down
34 changes: 34 additions & 0 deletions core/sdk.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,40 @@ type Runtime interface {
// Current instance of the module source.
dagql.Instance[*ModuleSource],
) (*Container, error)

/*
HasModuleTypeDefs checks if the module exposes a `moduleTypeDefs` function
to be called by `TypeDefs`.

This doesn't rely on a function exposed by the SDK, but on the list of functions
exposed.
*/
HasModuleTypeDefs() bool

/*
TypeDefs returns a container that is used to execute module code
to retrieve the types defined by the module.

This function will call the following function exposed by the SDK:

```gql
moduleTypeDefs(
modSource: ModuleSource!
introspectionJSON: File!
): Container!
```

If this function is not exposed, it will fall back to `Runtime`.
*/
TypeDefs(
context.Context,

// Current module dependencies.
*ModDeps,

// Current instance of the module source.
dagql.Instance[*ModuleSource],
) (*Container, error)
}

/*
Expand Down
1 change: 1 addition & 0 deletions core/sdk/consts.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ var sdkFunctions = []string{
"withConfig",
"codegen",
"moduleRuntime",
"moduleTypeDefs",
"requiredClientGenerationFiles",
"generateClient",
}
Loading
Loading
0