Create public SSH key if it does not exist in `dstack pool add-ssh` by TheBits · Pull Request #1173 · dstackai/dstack · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Create public SSH key if it does not exist in `dstack pool add-ssh` #1173

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 18 additions & 7 deletions src/dstack/_internal/cli/commands/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from dstack._internal.core.models.runs import InstanceStatus, Requirements, get_policy_map
from dstack._internal.utils.common import pretty_date
from dstack._internal.utils.logging import get_logger
from dstack._internal.utils.ssh import convert_pkcs8_to_pem, generate_public_key, rsa_pkey_from_str
from dstack.api._public.resources import Resources
from dstack.api.utils import load_profile

Expand Down Expand Up @@ -273,7 +274,10 @@ def _add(self, args: argparse.Namespace) -> None:

# TODO(egor-s): user key must be added during the `run`, not `pool add`
user_priv_key = Path("~/.dstack/ssh/id_rsa").expanduser().read_text().strip()
user_pub_key = Path("~/.dstack/ssh/id_rsa.pub").expanduser().read_text().strip()
try:
user_pub_key = Path("~/.dstack/ssh/id_rsa.pub").expanduser().read_text().strip()
except FileNotFoundError:
user_pub_key = generate_public_key(rsa_pkey_from_str(user_priv_key))
user_ssh_key = SSHKey(public=user_pub_key, private=user_priv_key)

try:
Expand All @@ -293,19 +297,26 @@ def _add_ssh(self, args: argparse.Namespace) -> None:

try:
# TODO: user key must be added during the `run`, not `pool add`
user_priv_key = Path("~/.dstack/ssh/id_rsa").expanduser().read_text().strip()
user_pub_key = Path("~/.dstack/ssh/id_rsa.pub").expanduser().read_text().strip()
user_priv_key = convert_pkcs8_to_pem(
Path("~/.dstack/ssh/id_rsa").expanduser().read_text().strip()
)
try:
user_pub_key = Path("~/.dstack/ssh/id_rsa.pub").expanduser().read_text().strip()
except FileNotFoundError:
user_pub_key = generate_public_key(rsa_pkey_from_str(user_priv_key))
user_ssh_key = SSHKey(public=user_pub_key, private=user_priv_key)
ssh_keys.append(user_ssh_key)
except OSError:
pass

if args.ssh_identity_file:
try:
ssh_key = SSHKey(
public=args.ssh_identity_file.with_suffix(".pub").read_text(),
private=args.ssh_identity_file.read_text(),
)
private_key = convert_pkcs8_to_pem(args.ssh_identity_file.read_text())
try:
pub_key = args.ssh_identity_file.with_suffix(".pub").read_text()
except FileNotFoundError:
pub_key = generate_public_key(rsa_pkey_from_str(private_key))
ssh_key = SSHKey(public=pub_key, private=private_key)
ssh_keys.append(ssh_key)
except OSError:
console.print("[error]Unable to read the public key.[/]")
Expand Down
42 changes: 27 additions & 15 deletions src/dstack/_internal/core/backends/remote/provisioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,20 +145,32 @@ def host_info_to_instance_type(host_info: Dict[str, Any]) -> InstanceType:

# NOTE(review): scraped diff hunk — removed and added lines are interleaved
# without +/- markers, so this span is not valid Python as shown. The
# `pkey: paramiko.PKey` parameter line, the outer `try:` with a single
# `client.connect(...)`, and the final `yield client` / `except` lines appear
# to be the removed version; the `pkeys: List[paramiko.PKey]` line and the
# per-key `for` loop appear to be the added version.
#
# New version: open one paramiko.SSHClient and try each key in `pkeys` in
# order, with local key files and the SSH agent disabled
# (look_for_keys=False, allow_agent=False). Yields the connected client.
# Raises ProvisioningError on a non-authentication SSH/OS error, or when
# every key fails authentication.
@contextmanager
def get_paramiko_connection(
ssh_user: str, host: str, port: int, pkey: paramiko.PKey
ssh_user: str, host: str, port: int, pkeys: List[paramiko.PKey]
) -> Generator[paramiko.SSHClient, None, None]:
try:
with paramiko.SSHClient() as client:
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
client.connect(
username=ssh_user,
hostname=host,
port=port,
pkey=pkey,
look_for_keys=False,
allow_agent=False,
timeout=SSH_CONNECT_TIMEOUT,
with paramiko.SSHClient() as client:
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
# Try each key in turn: AuthenticationException moves on to the next key;
# any other paramiko.SSHException or OSError aborts the whole attempt as
# ProvisioningError. The AuthenticationException clause must stay ahead of
# the SSHException clause, since paramiko derives it from SSHException.
for pkey in pkeys:
conn_url = f"{ssh_user}@{host}:{port}"
try:
logger.debug("Try to connect to %s with key %s", conn_url, pkey.fingerprint)
client.connect(
username=ssh_user,
hostname=host,
port=port,
pkey=pkey,
look_for_keys=False,
allow_agent=False,
timeout=SSH_CONNECT_TIMEOUT,
)
except paramiko.AuthenticationException:
continue # try next key
except (paramiko.SSHException, OSError) as e:
raise ProvisioningError() from e
# try/else: runs only after a successful connect. `yield` sits outside the
# try body, so exceptions raised inside the caller's `with` block are not
# converted by the except clauses above.
else:
yield client
return
# for/else: reached only when no key authenticated (no `return` was hit).
# NOTE(review): if `pkeys` is empty the loop body never runs, so `conn_url`
# is unbound here and building the message raises UnboundLocalError instead
# of ProvisioningError; assigning `conn_url` before the loop would avoid it.
# add_remote in process_instances.py rejects an empty key list before
# calling this; other callers are not visible here.
else:
keys_fp = ", ".join(f"{pk.fingerprint!r}" for pk in pkeys)
raise ProvisioningError(
f"SSH connection to the {conn_url} user with keys [{keys_fp}] was unsuccessful"
)
yield client
except (paramiko.SSHException, OSError) as e:
raise ProvisioningError() from e
22 changes: 13 additions & 9 deletions src/dstack/_internal/server/background/tasks/process_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@
from dstack._internal.utils.common import get_current_datetime
from dstack._internal.utils.logging import get_logger
from dstack._internal.utils.ssh import (
convert_pkcs8_to_pem,
rsa_pkey_from_str,
)

Expand Down Expand Up @@ -127,6 +126,10 @@ async def add_remote(instance_id: UUID) -> None:
)
).one()

if instance.status == InstanceStatus.PENDING:
instance.status = InstanceStatus.PROVISIONING
await session.commit()

retry_duration_deadline = instance.created_at.replace(
tzinfo=datetime.timezone.utc
) + timedelta(seconds=PROVISIONING_TIMEOUT_SECONDS)
Expand All @@ -153,17 +156,17 @@ async def add_remote(instance_id: UUID) -> None:
)

# Prepare connection key
try:
private_string = [
sk.private for sk in remote_details.ssh_keys if sk.private is not None
][0]
except IndexError:
pkeys = [
rsa_pkey_from_str(sk.private)
for sk in remote_details.ssh_keys
if sk.private is not None
]
if not pkeys:
logger.error("There are no ssh private key")
raise ConfigurationError("The SSH private key is not provided")
pkey = rsa_pkey_from_str(convert_pkcs8_to_pem(private_string))

with get_paramiko_connection(
remote_details.ssh_user, remote_details.host, remote_details.port, pkey
remote_details.ssh_user, remote_details.host, remote_details.port, pkeys
) as client:
logger.info(f"connected to {remote_details.ssh_user} {remote_details.host}")

Expand Down Expand Up @@ -192,7 +195,8 @@ async def add_remote(instance_id: UUID) -> None:
host_info = get_host_info(client, DSTACK_WORKING_DIR)
logger.debug("Received a host_info %s", host_info)

except ProvisioningError:
except ProvisioningError as e:
logger.warning("Provisioning could not be completed because of the error: %s", e)
instance.last_retry_at = get_current_datetime()
await session.commit()
return
Expand Down
7 changes: 7 additions & 0 deletions src/dstack/_internal/utils/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,8 @@ def convert_pkcs8_to_pem(private_string: str) -> str:
capture_output=True,
text=True,
)
except FileNotFoundError:
logger.error("Use a PEM key or install ssh-keygen to convert it automatically")
except subprocess.CalledProcessError as e:
logger.error("Fail to convert ssh key: stdout=%s, stderr=%s", e.stdout, e.stderr)

Expand All @@ -150,3 +152,8 @@ def rsa_pkey_from_str(private_string: str) -> PKey:
pkey = paramiko.RSAKey.from_private_key(key_file)
key_file.close()
return pkey


def generate_public_key(private_key: PKey) -> str:
    """Derive an OpenSSH-format public key line from a private key.

    The result has the form ``"<key-type> <base64-blob>"`` (for a paramiko
    RSA key, ``"ssh-rsa AAAA..."``). This is the layout of a ``.pub`` file
    and of an ``authorized_keys`` entry. That lets callers use it in place of
    the contents of ``id_rsa.pub`` when that file is missing.

    Args:
        private_key: a paramiko key object holding the private key.

    Returns:
        The public key as a single line, without a trailing comment.
    """
    # get_base64() yields only the key blob; OpenSSH public-key lines require
    # the algorithm name (get_name(), e.g. "ssh-rsa") in front of it.
    return f"{private_key.get_name()} {private_key.get_base64()}"
Loading
0