Add —config to allow the user to specify the config by 8W9aG · Pull Request #2291 · replicate/cog · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Add —config to allow the user to specify the config #2291

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
May 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion pkg/cli/build.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ var buildStrip bool
var buildPrecompile bool
var buildFast bool
var buildLocalImage bool
var configFilename string

const useCogBaseImageFlagKey = "use-cog-base-image"

Expand All @@ -53,6 +54,7 @@ func newBuildCommand() *cobra.Command {
addPrecompileFlag(cmd)
addFastFlag(cmd)
addLocalImage(cmd)
addConfigFlag(cmd)
cmd.Flags().StringVarP(&buildTag, "tag", "t", "", "A name for the built image in the form 'repository:tag'")
return cmd
}
Expand All @@ -68,7 +70,7 @@ func buildCommand(cmd *cobra.Command, args []string) error {
logClient := coglog.NewClient(client)
logCtx := logClient.StartBuild(buildFast, buildLocalImage)

cfg, projectDir, err := config.GetConfig()
cfg, projectDir, err := config.GetConfig(configFilename)
if err != nil {
logClient.EndBuild(ctx, err, logCtx)
return err
Expand Down Expand Up @@ -172,6 +174,11 @@ func addLocalImage(cmd *cobra.Command) {
_ = cmd.Flags().MarkHidden(localImage)
}

// addConfigFlag registers the config-file flag on cmd and binds it to the
// package-level configFilename variable, defaulting to "cog.yaml" when the
// user does not supply it.
//
// The flag is registered with StringVar, so "f" is its long name (passed on
// the command line as `--f <file>`) and no single-dash shorthand exists.
// NOTE(review): a one-letter long name is unusual; if `-f` was intended,
// StringVarP would be required — confirm none of the commands calling this
// already uses the "f" shorthand before changing it, since cobra panics on
// duplicate shorthands.
func addConfigFlag(cmd *cobra.Command) {
	flagSet := cmd.Flags()
	flagSet.StringVar(&configFilename, "f", "cog.yaml", "The name of the config file.")
}

func checkMutuallyExclusiveFlags(cmd *cobra.Command, args []string) error {
flags := []string{useCogBaseImageFlagKey, "use-cuda-base-image", "dockerfile"}
var flagsSet []string
Expand Down
6 changes: 3 additions & 3 deletions pkg/cli/debug.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"github.com/replicate/cog/pkg/config"
"github.com/replicate/cog/pkg/docker"
"github.com/replicate/cog/pkg/dockerfile"
"github.com/replicate/cog/pkg/global"
"github.com/replicate/cog/pkg/util/console"
)

Expand All @@ -18,7 +17,7 @@ func newDebugCommand() *cobra.Command {
cmd := &cobra.Command{
Use: "debug",
Hidden: true,
Short: "Generate a Dockerfile from " + global.ConfigFilename,
Short: "Generate a Dockerfile from cog",
RunE: cmdDockerfile,
}

Expand All @@ -29,6 +28,7 @@ func newDebugCommand() *cobra.Command {
addBuildTimestampFlag(cmd)
addFastFlag(cmd)
addLocalImage(cmd)
addConfigFlag(cmd)
cmd.Flags().StringVarP(&imageName, "image-name", "", "", "The image name to use for the generated Dockerfile")

return cmd
Expand All @@ -37,7 +37,7 @@ func newDebugCommand() *cobra.Command {
func cmdDockerfile(cmd *cobra.Command, args []string) error {
ctx := cmd.Context()

cfg, projectDir, err := config.GetConfig()
cfg, projectDir, err := config.GetConfig(configFilename)
if err != nil {
return err
}
Expand Down
3 changes: 2 additions & 1 deletion pkg/cli/migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ This will attempt to migrate your cog project to be compatible with fast boots.`
}

addYesFlag(cmd)
addConfigFlag(cmd)

return cmd
}
Expand All @@ -31,7 +32,7 @@ func cmdMigrate(cmd *cobra.Command, args []string) error {
if err != nil {
return err
}
err = migrator.Migrate(ctx)
err = migrator.Migrate(ctx, configFilename)
if err != nil {
return err
}
Expand Down
3 changes: 2 additions & 1 deletion pkg/cli/predict.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ the prediction on that.`,
addSetupTimeoutFlag(cmd)
addFastFlag(cmd)
addLocalImage(cmd)
addConfigFlag(cmd)

cmd.Flags().StringArrayVarP(&inputFlags, "input", "i", []string{}, "Inputs, in the form name=value. if value is prefixed with @, then it is read from a file on disk. E.g. -i path=@image.jpg")
cmd.Flags().StringVarP(&outPath, "output", "o", "", "Output path")
Expand All @@ -78,7 +79,7 @@ func cmdPredict(cmd *cobra.Command, args []string) error {
if len(args) == 0 {
// Build image

cfg, projectDir, err := config.GetConfig()
cfg, projectDir, err := config.GetConfig(configFilename)
if err != nil {
return err
}
Expand Down
3 changes: 2 additions & 1 deletion pkg/cli/push.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ func newPushCommand() *cobra.Command {
addPrecompileFlag(cmd)
addFastFlag(cmd)
addLocalImage(cmd)
addConfigFlag(cmd)

return cmd
}
Expand All @@ -54,7 +55,7 @@ func push(cmd *cobra.Command, args []string) error {
logClient := coglog.NewClient(client)
logCtx := logClient.StartPush(buildFast, buildLocalImage)

cfg, projectDir, err := config.GetConfig()
cfg, projectDir, err := config.GetConfig(configFilename)
if err != nil {
logClient.EndPush(ctx, err, logCtx)
return err
Expand Down
3 changes: 2 additions & 1 deletion pkg/cli/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ func newRunCommand() *cobra.Command {
addGpusFlag(cmd)
addFastFlag(cmd)
addLocalImage(cmd)
addConfigFlag(cmd)

flags := cmd.Flags()
// Flags after first argument are considered args and passed to command
Expand All @@ -54,7 +55,7 @@ func newRunCommand() *cobra.Command {
func run(cmd *cobra.Command, args []string) error {
ctx := cmd.Context()

cfg, projectDir, err := config.GetConfig()
cfg, projectDir, err := config.GetConfig(configFilename)
if err != nil {
return err
}
Expand Down
3 changes: 2 additions & 1 deletion pkg/cli/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ Generate and run an HTTP server based on the declared model inputs and outputs.`
addUseCogBaseImageFlag(cmd)
addGpusFlag(cmd)
addFastFlag(cmd)
addConfigFlag(cmd)

cmd.Flags().IntVarP(&port, "port", "p", port, "Port on which to listen")

Expand All @@ -43,7 +44,7 @@ Generate and run an HTTP server based on the declared model inputs and outputs.`
func cmdServe(cmd *cobra.Command, arg []string) error {
ctx := cmd.Context()

cfg, projectDir, err := config.GetConfig()
cfg, projectDir, err := config.GetConfig(configFilename)
if err != nil {
return err
}
Expand Down
3 changes: 2 additions & 1 deletion pkg/cli/train.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ Otherwise, it will build the model in the current directory and train it.`,
addGpusFlag(cmd)
addUseCogBaseImageFlag(cmd)
addFastFlag(cmd)
addConfigFlag(cmd)

cmd.Flags().StringArrayVarP(&trainInputFlags, "input", "i", []string{}, "Inputs, in the form name=value. if value is prefixed with @, then it is read from a file on disk. E.g. -i path=@image.jpg")
cmd.Flags().StringArrayVarP(&trainEnvFlags, "env", "e", []string{}, "Environment variables, in the form name=value")
Expand All @@ -61,7 +62,7 @@ func cmdTrain(cmd *cobra.Command, args []string) error {
volumes := []docker.Volume{}
gpus := gpusFlag

cfg, projectDir, err := config.GetConfig()
cfg, projectDir, err := config.GetConfig(configFilename)
if err != nil {
return err
}
Expand Down
26 changes: 13 additions & 13 deletions pkg/config/load.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,30 +7,30 @@ import (
"path/filepath"

"github.com/replicate/cog/pkg/errors"
"github.com/replicate/cog/pkg/global"
"github.com/replicate/cog/pkg/util/files"
)

const maxSearchDepth = 100

// Returns the project's root directory, or the directory specified by the --project-dir flag
func GetProjectDir() (string, error) {
func GetProjectDir(configFilename string) (string, error) {
cwd, err := os.Getwd()
if err != nil {
return "", err
}
return findProjectRootDir(cwd)
return findProjectRootDir(cwd, configFilename)
}

// Loads and instantiates a Config object
// customDir can be specified to override the default - current working directory
func GetConfig() (*Config, string, error) {
func GetConfig(configFilename string) (*Config, string, error) {
// Find the root project directory
rootDir, err := GetProjectDir()
rootDir, err := GetProjectDir(configFilename)

if err != nil {
return nil, "", err
}
configPath := path.Join(rootDir, global.ConfigFilename)
configPath := path.Join(rootDir, configFilename)

// Then try to load the config file from there
config, err := loadConfigFromFile(configPath)
Expand All @@ -51,7 +51,7 @@ func loadConfigFromFile(file string) (*Config, error) {
}

if !exists {
return nil, fmt.Errorf("%s does not exist in %s. Are you in the right directory?", global.ConfigFilename, filepath.Dir(file))
return nil, fmt.Errorf("%s does not exist in %s. Are you in the right directory?", filepath.Base(file), filepath.Dir(file))
}

contents, err := os.ReadFile(file)
Expand All @@ -69,30 +69,30 @@ func loadConfigFromFile(file string) (*Config, error) {
}

// Given a directory, find the cog config file in that directory
func findConfigPathInDirectory(dir string) (configPath string, err error) {
filePath := path.Join(dir, global.ConfigFilename)
func findConfigPathInDirectory(dir string, configFilename string) (configPath string, err error) {
filePath := path.Join(dir, configFilename)
exists, err := files.Exists(filePath)
if err != nil {
return "", fmt.Errorf("Failed to scan directory %s for %s: %s", dir, filePath, err)
} else if exists {
return filePath, nil
}

return "", errors.ConfigNotFound(fmt.Sprintf("%s not found in %s", global.ConfigFilename, dir))
return "", errors.ConfigNotFound(fmt.Sprintf("%s not found in %s", configFilename, dir))
}

// Walk up the directory tree to find the root of the project.
// The project root is defined as the directory housing a `cog.yaml` file.
func findProjectRootDir(startDir string) (string, error) {
func findProjectRootDir(startDir string, configFilename string) (string, error) {
dir := startDir
for i := 0; i < maxSearchDepth; i++ {
switch _, err := findConfigPathInDirectory(dir); {
switch _, err := findConfigPathInDirectory(dir, configFilename); {
case err != nil && !errors.IsConfigNotFound(err):
return "", err
case err == nil:
return dir, nil
case dir == "." || dir == "/":
return "", errors.ConfigNotFound(fmt.Sprintf("%s not found in %s (or in any parent directories)", global.ConfigFilename, startDir))
return "", errors.ConfigNotFound(fmt.Sprintf("%s not found in %s (or in any parent directories)", configFilename, startDir))
}

dir = filepath.Dir(dir)
Expand Down
4 changes: 2 additions & 2 deletions pkg/config/load_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func TestFindProjectRootDirShouldFindParentDir(t *testing.T) {
err = os.MkdirAll(subdir, 0o700)
require.NoError(t, err)

foundDir, err := findProjectRootDir(subdir)
foundDir, err := findProjectRootDir(subdir, "cog.yaml")
require.NoError(t, err)
require.Equal(t, foundDir, projectDir)
}
Expand All @@ -40,6 +40,6 @@ func TestFindProjectRootDirShouldReturnErrIfNoConfig(t *testing.T) {
err := os.MkdirAll(subdir, 0o700)
require.NoError(t, err)

_, err = findProjectRootDir(subdir)
_, err = findProjectRootDir(subdir, "cog.yaml")
require.Error(t, err)
}
1 change: 0 additions & 1 deletion pkg/global/global.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ var (
BuildTime = "none"
Debug = false
ProfilingEnabled = false
ConfigFilename = "cog.yaml"
ReplicateRegistryHost = "r8.im"
ReplicateWebsiteHost = "replicate.com"
LabelNamespace = "run.cog."
Expand Down
2 changes: 1 addition & 1 deletion pkg/migrate/migrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@ package migrate
import "context"

type Migrator interface {
Migrate(ctx context.Context) error
Migrate(ctx context.Context, configFilename string) error
}
11 changes: 5 additions & 6 deletions pkg/migrate/migrator_v1_v1fast.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ import (

"github.com/replicate/cog/pkg/config"
"github.com/replicate/cog/pkg/dockerfile"
"github.com/replicate/cog/pkg/global"
"github.com/replicate/cog/pkg/requirements"
"github.com/replicate/cog/pkg/util"
"github.com/replicate/cog/pkg/util/console"
Expand All @@ -40,8 +39,8 @@ func NewMigratorV1ToV1Fast(interactive bool) *MigratorV1ToV1Fast {
}
}

func (g *MigratorV1ToV1Fast) Migrate(ctx context.Context) error {
cfg, projectDir, err := config.GetConfig()
func (g *MigratorV1ToV1Fast) Migrate(ctx context.Context, configFilename string) error {
cfg, projectDir, err := config.GetConfig(configFilename)
if err != nil {
return err
}
Expand All @@ -57,7 +56,7 @@ func (g *MigratorV1ToV1Fast) Migrate(ctx context.Context) error {
if err != nil {
return err
}
err = g.flushConfig(cfg, projectDir)
err = g.flushConfig(cfg, projectDir, configFilename)
return err
}

Expand Down Expand Up @@ -167,7 +166,7 @@ func (g *MigratorV1ToV1Fast) checkPythonCode(ctx context.Context, cfg *config.Co
return nil
}

func (g *MigratorV1ToV1Fast) flushConfig(cfg *config.Config, dir string) error {
func (g *MigratorV1ToV1Fast) flushConfig(cfg *config.Config, dir string, configFilename string) error {
if cfg.Build == nil {
cfg.Build = config.DefaultConfig().Build
}
Expand All @@ -182,7 +181,7 @@ func (g *MigratorV1ToV1Fast) flushConfig(cfg *config.Config, dir string) error {
}
configStr := string(data)

configFilepath := filepath.Join(dir, global.ConfigFilename)
configFilepath := filepath.Join(dir, configFilename)
file, err := os.Open(configFilepath)
if err != nil {
return err
Expand Down
5 changes: 2 additions & 3 deletions pkg/migrate/migrator_v1_v1fast_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (

"github.com/stretchr/testify/require"

"github.com/replicate/cog/pkg/global"
"github.com/replicate/cog/pkg/requirements"
)

Expand All @@ -25,7 +24,7 @@ func TestMigrate(t *testing.T) {
require.NoError(t, err)

// Write our test configs/code
configFilepath := filepath.Join(dir, global.ConfigFilename)
configFilepath := filepath.Join(dir, "cog.yaml")
file, err := os.Create(configFilepath)
require.NoError(t, err)
_, err = file.WriteString(`build:
Expand Down Expand Up @@ -56,7 +55,7 @@ class Predictor(BasePredictor):

// Perform the migration
migrator := NewMigratorV1ToV1Fast(false)
err = migrator.Migrate(t.Context())
err = migrator.Migrate(t.Context(), "cog.yaml")
require.NoError(t, err)

// Check config output
Expand Down
0