[FIX] Error when using different primary container name in driver and executor by machichima · Pull Request #6363 · flyteorg/flyte · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

[FIX] Error when using different primary container name in driver and executor #6363

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 28 additions & 28 deletions flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
const defaultContainerTemplateName = "default"
const defaultInitContainerTemplateName = "default-init"
const primaryContainerTemplateName = "primary"
const primaryInitContainerTemplateName = "primary-init"
const PrimaryInitContainerTemplateName = "primary-init"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are reserved container names, I don't think they should be exported and used in the spark plugin logic

const PrimaryContainerKey = "primary_container_name"

// AddRequiredNodeSelectorRequirements adds the provided v1.NodeSelectorRequirement
Expand Down Expand Up @@ -716,7 +716,7 @@
for i := 0; i < len(templatePodSpec.InitContainers); i++ {
if templatePodSpec.InitContainers[i].Name == defaultInitContainerTemplateName {
defaultInitContainerTemplate = &templatePodSpec.InitContainers[i]
} else if templatePodSpec.InitContainers[i].Name == primaryInitContainerTemplateName {
} else if templatePodSpec.InitContainers[i].Name == PrimaryInitContainerTemplateName {
primaryInitContainerTemplate = &templatePodSpec.InitContainers[i]
}
}
Expand Down Expand Up @@ -746,20 +746,20 @@
return nil, err
}
}
}

// Check for any name matching template containers
for _, templateContainer := range templatePodSpec.Containers {
if templateContainer.Name != container.Name {
continue
}
} else {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change seems okay to me. Can we have the same behavior for init containers to keep it consistent?

// Check for any name matching template containers
for _, templateContainer := range templatePodSpec.Containers {
if templateContainer.Name != container.Name {
continue
}

if mergedContainer == nil {
mergedContainer = &templateContainer
} else {
err := mergo.Merge(mergedContainer, templateContainer, mergo.WithOverride, mergo.WithAppendSlice)
if err != nil {
return nil, err
if mergedContainer == nil {
mergedContainer = &templateContainer
} else {
err := mergo.Merge(mergedContainer, templateContainer, mergo.WithOverride, mergo.WithAppendSlice)
if err != nil {
return nil, err
}

Check warning on line 762 in flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go#L759-L762

Added lines #L759 - L762 were not covered by tests
}
}
}
Expand Down Expand Up @@ -799,20 +799,20 @@
return nil, err
}
}
}

// Check for any name matching template containers
for _, templateInitContainer := range templatePodSpec.InitContainers {
if templateInitContainer.Name != initContainer.Name {
continue
}
} else {
// Check for any name matching template containers
for _, templateInitContainer := range templatePodSpec.InitContainers {
if templateInitContainer.Name != initContainer.Name {
continue
}

if mergedInitContainer == nil {
mergedInitContainer = &templateInitContainer
} else {
err := mergo.Merge(mergedInitContainer, templateInitContainer, mergo.WithOverride, mergo.WithAppendSlice)
if err != nil {
return nil, err
if mergedInitContainer == nil {
mergedInitContainer = &templateInitContainer
} else {
err := mergo.Merge(mergedInitContainer, templateInitContainer, mergo.WithOverride, mergo.WithAppendSlice)
if err != nil {
return nil, err
}

Check warning on line 815 in flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go

View check run for this annotation

Codecov / codecov/patch

flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go#L812-L815

Added lines #L812 - L815 were not covered by tests
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2100,7 +2100,7 @@ func TestMergeWithBasePodTemplate(t *testing.T) {
}

primaryInitContainerTemplate := v1.Container{
Name: primaryInitContainerTemplateName,
Name: PrimaryInitContainerTemplateName,
TerminationMessagePath: "/dev/primary-init-termination-log",
}

Expand Down
18 changes: 10 additions & 8 deletions flyteplugins/go/tasks/plugins/k8s/spark/spark.go
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,10 @@ func createDriverSpec(ctx context.Context, taskCtx pluginsCore.TaskExecutionCont

driverPod := sparkJob.GetDriverPod()
if driverPod != nil {
if driverPod.GetPrimaryContainerName() != "" {
primaryContainerName = driverPod.GetPrimaryContainerName()
}

if driverPod.GetPodSpec() != nil {
var customPodSpec *v1.PodSpec

Expand All @@ -209,15 +213,12 @@ func createDriverSpec(ctx context.Context, taskCtx pluginsCore.TaskExecutionCont
"Unable to unmarshal driver pod spec [%v], Err: [%v]", driverPod.GetPodSpec(), err.Error())
}

podSpec, err = flytek8s.MergeOverlayPodSpecOntoBase(podSpec, customPodSpec)
podSpec, err = flytek8s.MergeBasePodSpecOntoTemplate(podSpec, customPodSpec, primaryContainerName, flytek8s.PrimaryInitContainerTemplateName)
if err != nil {
return nil, err
}
}

if driverPod.GetPrimaryContainerName() != "" {
primaryContainerName = driverPod.GetPrimaryContainerName()
}
}

primaryContainer, err := flytek8s.GetContainer(podSpec, primaryContainerName)
Expand Down Expand Up @@ -253,6 +254,10 @@ func createExecutorSpec(ctx context.Context, taskCtx pluginsCore.TaskExecutionCo

executorPod := sparkJob.GetExecutorPod()
if executorPod != nil {
if executorPod.GetPrimaryContainerName() != "" {
primaryContainerName = executorPod.GetPrimaryContainerName()
}

if executorPod.GetPodSpec() != nil {
var customPodSpec *v1.PodSpec

Expand All @@ -262,14 +267,11 @@ func createExecutorSpec(ctx context.Context, taskCtx pluginsCore.TaskExecutionCo
"Unable to unmarshal executor pod spec [%v], Err: [%v]", executorPod.GetPodSpec(), err.Error())
}

podSpec, err = flytek8s.MergeOverlayPodSpecOntoBase(podSpec, customPodSpec)
podSpec, err = flytek8s.MergeBasePodSpecOntoTemplate(podSpec, customPodSpec, primaryContainerName, flytek8s.PrimaryInitContainerTemplateName)
if err != nil {
return nil, err
}
}
if executorPod.GetPrimaryContainerName() != "" {
primaryContainerName = executorPod.GetPrimaryContainerName()
}
}

primaryContainer, err := flytek8s.GetContainer(podSpec, primaryContainerName)
Expand Down
102 changes: 102 additions & 0 deletions flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1140,6 +1140,108 @@ func TestBuildResourceCustomK8SPod(t *testing.T) {
assert.Equal(t, dummySparkConf["spark.executor.memory"], *sparkApp.Spec.Executor.Memory)
}

// Ensure setting different primary name for driver and executor works
// TestBuildResourceCustomK8SPodChangePrimaryContainerName ensures that the driver and the
// executor can each declare their own primary container name ("driver-primary" and
// "executor-primary") through K8SPod.PrimaryContainerName. It checks that BuildResource still
// produces a SparkApplication whose driver and executor specs are populated from those
// containers.
func TestBuildResourceCustomK8SPodChangePrimaryContainerName(t *testing.T) {
	defaultConfig := defaultPluginConfig()
	assert.NoError(t, config.SetK8sPluginConfig(defaultConfig))

	// Base pod spec carried by the task template; the driver/executor pod specs below are
	// supplied alongside it.
	basePodSpec := dummyPodSpec()
	basePodSpec.NodeSelector = map[string]string{"x/custom": "foo"}

	// Driver and executor pod specs whose single container names differ from each other.
	// This is the scenario under test.
	driverPodSpec := &corev1.PodSpec{
		Containers: []corev1.Container{
			{
				Name:  "driver-primary",
				Image: testImage,
				Args:  testArgs,
				Env:   dummyEnvVarsWithSecretRef,
			},
		},
	}
	executorPodSpec := &corev1.PodSpec{
		Containers: []corev1.Container{
			{
				Name:  "executor-primary",
				Image: testImage,
				Args:  testArgs,
				Env:   dummyEnvVarsWithSecretRef,
			},
		},
	}

	// PrimaryContainerName must match the container names declared above.
	driverK8SPod := &core.K8SPod{
		PodSpec:              transformStructToStructPB(t, driverPodSpec),
		PrimaryContainerName: "driver-primary",
	}
	executorK8SPod := &core.K8SPod{
		PodSpec:              transformStructToStructPB(t, executorPodSpec),
		PrimaryContainerName: "executor-primary",
	}

	taskTemplate := dummySparkTaskTemplateDriverExecutor("blah-1", dummySparkConf, driverK8SPod, executorK8SPod, basePodSpec)
	sparkResourceHandler := sparkResourceHandler{}

	taskCtx := dummySparkTaskContext(taskTemplate, true, k8s.PluginState{})
	resource, err := sparkResourceHandler.BuildResource(context.TODO(), taskCtx)

	// The driver/executor spec builders look up the primary container by name via
	// flytek8s.GetContainer, so a name mismatch is expected to surface as an error here.
	// Stop on failure: everything below dereferences the built resource.
	if !assert.NoError(t, err) {
		t.FailNow()
	}
	assert.NotNil(t, resource)
	sparkApp, ok := resource.(*sj.SparkApplication)
	if !ok {
		t.Fatalf("expected *sj.SparkApplication, got %T", resource)
	}

	// Application
	assert.Equal(t, v1.TypeMeta{
		Kind:       KindSparkApplication,
		APIVersion: sparkOp.SchemeGroupVersion.String(),
	}, sparkApp.TypeMeta)

	// Application spec
	assert.Equal(t, flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata()), *sparkApp.Spec.ServiceAccount)
	assert.Equal(t, sparkOp.PythonApplicationType, sparkApp.Spec.Type)
	assert.Equal(t, testImage, *sparkApp.Spec.Image)
	assert.Equal(t, append(testArgs, testArgs...), sparkApp.Spec.Arguments)
	assert.Equal(t, sparkOp.RestartPolicy{
		Type:                       sparkOp.OnFailure,
		OnSubmissionFailureRetries: intPtr(int32(14)),
	}, sparkApp.Spec.RestartPolicy)
	assert.Equal(t, sparkMainClass, *sparkApp.Spec.MainClass)
	assert.Equal(t, sparkApplicationFile, *sparkApp.Spec.MainApplicationFile)

	// Driver
	// Expected value first so testify reports "expected/actual" correctly on failure.
	assert.Equal(t, 1, len(findEnvVarByName(sparkApp.Spec.Driver.Env, "FLYTE_MAX_ATTEMPTS").Value))
	assert.Equal(t, defaultConfig.DefaultEnvVars["foo"], findEnvVarByName(sparkApp.Spec.Driver.Env, "foo").Value)
	assert.Equal(t, defaultConfig.DefaultEnvVars["fooEnv"], findEnvVarByName(sparkApp.Spec.Driver.Env, "fooEnv").Value)
	assert.Equal(t, findEnvVarByName(dummyEnvVarsWithSecretRef, "SECRET"), findEnvVarByName(sparkApp.Spec.Driver.Env, "SECRET"))
	assert.Equal(t, 11, len(sparkApp.Spec.Driver.Env))
	assert.Equal(t, testImage, *sparkApp.Spec.Driver.Image)
	assert.Equal(t, flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata()), *sparkApp.Spec.Driver.ServiceAccount)
	assert.Equal(t, defaultConfig.DefaultPodSecurityContext, sparkApp.Spec.Driver.SecurityContenxt)
	assert.Equal(t, defaultConfig.DefaultPodDNSConfig, sparkApp.Spec.Driver.DNSConfig)
	assert.Equal(t, defaultConfig.EnableHostNetworkingPod, sparkApp.Spec.Driver.HostNetwork)
	assert.Equal(t, defaultConfig.SchedulerName, *sparkApp.Spec.Driver.SchedulerName)
	driverCores, err := strconv.ParseInt(dummySparkConf["spark.driver.cores"], 10, 32)
	assert.NoError(t, err)
	assert.Equal(t, intPtr(int32(driverCores)), sparkApp.Spec.Driver.Cores)
	assert.Equal(t, dummySparkConf["spark.driver.memory"], *sparkApp.Spec.Driver.Memory)

	// Executor
	assert.Equal(t, defaultConfig.DefaultEnvVars["foo"], findEnvVarByName(sparkApp.Spec.Executor.Env, "foo").Value)
	assert.Equal(t, defaultConfig.DefaultEnvVars["fooEnv"], findEnvVarByName(sparkApp.Spec.Executor.Env, "fooEnv").Value)
	assert.Equal(t, findEnvVarByName(dummyEnvVarsWithSecretRef, "SECRET"), findEnvVarByName(sparkApp.Spec.Executor.Env, "SECRET"))
	assert.Equal(t, 11, len(sparkApp.Spec.Executor.Env))
	assert.Equal(t, testImage, *sparkApp.Spec.Executor.Image)
	assert.Equal(t, defaultConfig.DefaultPodSecurityContext, sparkApp.Spec.Executor.SecurityContenxt)
	assert.Equal(t, defaultConfig.DefaultPodDNSConfig, sparkApp.Spec.Executor.DNSConfig)
	assert.Equal(t, defaultConfig.EnableHostNetworkingPod, sparkApp.Spec.Executor.HostNetwork)
	assert.Equal(t, defaultConfig.SchedulerName, *sparkApp.Spec.Executor.SchedulerName)
	executorCores, err := strconv.ParseInt(dummySparkConf["spark.executor.cores"], 10, 32)
	assert.NoError(t, err)
	executorInstances, err := strconv.ParseInt(dummySparkConf["spark.executor.instances"], 10, 32)
	assert.NoError(t, err)
	assert.Equal(t, intPtr(int32(executorInstances)), sparkApp.Spec.Executor.Instances)
	assert.Equal(t, intPtr(int32(executorCores)), sparkApp.Spec.Executor.Cores)
	assert.Equal(t, dummySparkConf["spark.executor.memory"], *sparkApp.Spec.Executor.Memory)
}

func transformStructToStructPB(t *testing.T, obj interface{}) *structpb.Struct {
data, err := json.Marshal(obj)
assert.Nil(t, err)
Expand Down
Loading
0