Add request timeouts to prevent code freezes by jvstme · Pull Request #1140 · dstackai/dstack · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Add request timeouts to prevent code freezes #1140

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ruff.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ target-version = "py38"
line-length = 99

[lint]
select = ['E', 'F', 'I' ,'Q', 'W', 'PGH', 'FLY']
select = ['E', 'F', 'I' ,'Q', 'W', 'PGH', 'FLY', 'S113']
ignore =[
'E501',
'E712',
Expand Down
1 change: 1 addition & 0 deletions src/dstack/_internal/cli/utils/updates.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def get_latest_version() -> Optional[str]:
"anonymous_installation_id": anonymous_installation_id(),
"version": version.__version__,
},
timeout=5,
)
if response.status_code == 200:
return response.text.strip()
Expand Down
3 changes: 2 additions & 1 deletion src/dstack/_internal/core/backends/base/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ def get_latest_runner_build() -> Optional[str]:
params={
"status": "success",
},
timeout=10,
)
resp.raise_for_status()

Expand All @@ -239,7 +240,7 @@ def get_dstack_gateway_wheel(build: str) -> str:
channel = "release" if settings.DSTACK_RELEASE else "stgn"
base_url = f"https://dstack-gateway-downloads.s3.amazonaws.com/{channel}"
if build == "latest":
r = requests.get(f"{base_url}/latest-version")
r = requests.get(f"{base_url}/latest-version", timeout=5)
r.raise_for_status()
build = r.text.strip()
logger.debug("Found the latest gateway build: %s", build)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def add_ssh_key(self, name: str, public_key: str) -> List[Dict]:
resp.raise_for_status()

def _make_request(self, method: str, path: str, data: Any = None):
# TODO: set adequate timeout here or in every method
return requests.request(
method=method,
url=API_URL + path,
Expand Down
20 changes: 16 additions & 4 deletions src/dstack/_internal/core/backends/nebius/api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

logger = get_logger("nebius")
API_URL = "api.ai.nebius.cloud"
REQUEST_TIMEOUT = 15


class NebiusAPIClient:
Expand Down Expand Up @@ -45,7 +46,7 @@ def get_token(self):
headers={"kid": self.service_account["id"]},
)

resp = requests.post(payload["aud"], json={"jwt": jwt_token})
resp = requests.post(payload["aud"], json={"jwt": jwt_token}, timeout=REQUEST_TIMEOUT)
resp.raise_for_status()
iam_token = resp.json()["iamToken"]
self.s.headers["Authorization"] = f"Bearer {iam_token}"
Expand All @@ -54,7 +55,7 @@ def get_token(self):
def compute_zones_list(self) -> List[dict]:
logger.debug("Fetching compute zones")
self.get_token()
resp = self.s.get(self.url("compute", "/zones"))
resp = self.s.get(self.url("compute", "/zones"), timeout=REQUEST_TIMEOUT)
self.raise_for_status(resp)
return resp.json()["zones"]

Expand All @@ -68,6 +69,7 @@ def resource_manager_folders_create(self, cloud_id: str, name: str, **kwargs) ->
name=name,
**kwargs,
),
timeout=REQUEST_TIMEOUT,
)
self.raise_for_status(resp)
return resp.json()
Expand All @@ -82,6 +84,7 @@ def vpc_networks_create(self, folder_id: str, name: str, **kwargs) -> dict:
name=name,
**kwargs,
),
timeout=REQUEST_TIMEOUT,
)
self.raise_for_status(resp)
return resp.json()
Expand Down Expand Up @@ -118,6 +121,7 @@ def vpc_subnets_create(
v4CidrBlocks=cird_blocks,
**kwargs,
),
timeout=REQUEST_TIMEOUT,
)
self.raise_for_status(resp)
return resp.json()
Expand Down Expand Up @@ -147,6 +151,7 @@ def vpc_security_groups_create(
ruleSpecs=rule_specs,
**kwargs,
),
timeout=REQUEST_TIMEOUT,
)
self.raise_for_status(resp)
return resp.json()
Expand All @@ -165,7 +170,9 @@ def vpc_security_groups_list(self, folder_id: str, filter: Optional[str] = None)
def vpc_security_groups_delete(self, security_group_id: str):
logger.debug("Deleting security group %s", security_group_id)
self.get_token()
resp = self.s.delete(self.url("vpc", f"/securityGroups/{security_group_id}"))
resp = self.s.delete(
self.url("vpc", f"/securityGroups/{security_group_id}"), timeout=REQUEST_TIMEOUT
)
self.raise_for_status(resp)

def compute_instances_create(
Expand Down Expand Up @@ -215,6 +222,7 @@ def compute_instances_create(
],
**kwargs,
),
timeout=REQUEST_TIMEOUT,
)
self.raise_for_status(resp)
return resp.json()
Expand All @@ -236,7 +244,9 @@ def compute_instances_list(
def compute_instances_delete(self, instance_id: str):
logger.debug("Deleting instance %s", instance_id)
self.get_token()
resp = self.s.delete(self.url("compute", f"/instances/{instance_id}"))
resp = self.s.delete(
self.url("compute", f"/instances/{instance_id}"), timeout=REQUEST_TIMEOUT
)
self.raise_for_status(resp)

def compute_instances_get(self, instance_id: str, full: bool = False) -> dict:
Expand All @@ -247,6 +257,7 @@ def compute_instances_get(self, instance_id: str, full: bool = False) -> dict:
params=dict(
view="FULL" if full else "BASIC",
),
timeout=REQUEST_TIMEOUT,
)
self.raise_for_status(resp)
return resp.json()
Expand Down Expand Up @@ -277,6 +288,7 @@ def list(self, service: str, resource: str, params: dict, page_size: int = 1000)
pageToken=page_token,
**params,
),
timeout=REQUEST_TIMEOUT,
)
self.raise_for_status(resp)
data = resp.json()
Expand Down
1 change: 1 addition & 0 deletions src/dstack/_internal/core/backends/runpod/api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def terminate_pod(self, pod_id: str) -> Dict:

def _make_request(self, data: Any = None) -> Response:
try:
# TODO: set adequate timeout here or in every method
response = requests.request(
method="POST",
url=f"{API_URL}?api_key={self.api_key}",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def __init__(self, api_key: str, api_token: str):
self.api_url = "https://marketplace.tensordock.com/api/v0".rstrip("/")
self.api_key = api_key
self.api_token = api_token
self.s = requests.Session()
self.s = requests.Session() # TODO: set adequate timeout everywhere the session is used

def auth_test(self) -> bool:
resp = self.s.post(
Expand Down
2 changes: 1 addition & 1 deletion src/dstack/_internal/core/backends/vastai/api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class VastAIAPIClient:
def __init__(self, api_key: str):
self.api_url = "https://console.vast.ai/api/v0".rstrip("/")
self.api_key = api_key
self.s = requests.Session()
self.s = requests.Session() # TODO: set adequate timeout everywhere the session is used
retries = Retry(
total=5,
backoff_factor=1,
Expand Down
2 changes: 1 addition & 1 deletion src/dstack/_internal/core/services/repos.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def get_local_repo_credentials(
original_hostname: Optional[str] = None,
) -> RemoteRepoCreds:
url = repo_data.make_url(RepoProtocol.HTTPS) # no auth
r = requests.get(f"{url}/info/refs?service=git-upload-pack")
r = requests.get(f"{url}/info/refs?service=git-upload-pack", timeout=10)
if r.status_code == 200:
return RemoteRepoCreds(protocol=RepoProtocol.HTTPS, private_key=None, oauth_token=None)

Expand Down
2 changes: 1 addition & 1 deletion src/dstack/_internal/server/services/gateways/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def complete_service_model(model_info: AnyModel):
def get_tokenizer_config(model_id: str) -> dict:
try:
resp = requests.get(
f"https://huggingface.co/{model_id}/resolve/main/tokenizer_config.json"
f"https://huggingface.co/{model_id}/resolve/main/tokenizer_config.json", timeout=10
)
if resp.status_code == 403:
raise ConfigurationError("Private HF models are not supported")
Expand Down
21 changes: 12 additions & 9 deletions src/dstack/_internal/server/services/runner/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

REMOTE_SHIM_PORT = 10998
REMOTE_RUNNER_PORT = 10999
REQUEST_TIMEOUT = 15


class RunnerClient:
Expand All @@ -31,7 +32,7 @@ def __init__(

def healthcheck(self) -> Optional[HealthcheckResponse]:
try:
resp = requests.get(self._url("/api/healthcheck"))
resp = requests.get(self._url("/api/healthcheck"), timeout=REQUEST_TIMEOUT)
resp.raise_for_status()
return HealthcheckResponse.__response__.parse_obj(resp.json())
except requests.exceptions.RequestException:
Expand All @@ -57,26 +58,27 @@ def submit_job(
self._url("/api/submit"),
data=body.json(),
headers={"Content-Type": "application/json"},
timeout=REQUEST_TIMEOUT,
)
resp.raise_for_status()

def upload_code(self, file: Union[BinaryIO, bytes]):
resp = requests.post(self._url("/api/upload_code"), data=file)
resp = requests.post(self._url("/api/upload_code"), data=file, timeout=REQUEST_TIMEOUT)
resp.raise_for_status()

def run_job(self):
resp = requests.post(self._url("/api/run"))
resp = requests.post(self._url("/api/run"), timeout=REQUEST_TIMEOUT)
resp.raise_for_status()

def pull(self, timestamp: int, timeout: int = 5) -> PullResponse:
def pull(self, timestamp: int) -> PullResponse:
resp = requests.get(
self._url("/api/pull"), params={"timestamp": timestamp}, timeout=timeout
self._url("/api/pull"), params={"timestamp": timestamp}, timeout=REQUEST_TIMEOUT
)
resp.raise_for_status()
return PullResponse.__response__.parse_obj(resp.json())

def stop(self):
resp = requests.post(self._url("/api/stop"))
resp = requests.post(self._url("/api/stop"), timeout=REQUEST_TIMEOUT)
resp.raise_for_status()

def _url(self, path: str) -> str:
Expand All @@ -95,7 +97,7 @@ def __init__(

def healthcheck(self, unmask_exeptions: bool = False) -> Optional[HealthcheckResponse]:
try:
resp = requests.get(self._url("/api/healthcheck"), timeout=5)
resp = requests.get(self._url("/api/healthcheck"), timeout=REQUEST_TIMEOUT)
resp.raise_for_status()
return HealthcheckResponse.__response__.parse_obj(resp.json())
except requests.exceptions.RequestException:
Expand All @@ -122,16 +124,17 @@ def submit(
resp = requests.post(
self._url("/api/submit"),
json=post_body,
timeout=REQUEST_TIMEOUT,
)
resp.raise_for_status()

def stop(self, force: bool = False):
body = StopBody(force=force)
resp = requests.post(self._url("/api/stop"), json=body.dict())
resp = requests.post(self._url("/api/stop"), json=body.dict(), timeout=REQUEST_TIMEOUT)
resp.raise_for_status()

def pull(self) -> PullBody:
resp = requests.get(self._url("/api/pull"))
resp = requests.get(self._url("/api/pull"), timeout=REQUEST_TIMEOUT)
resp.raise_for_status()
return PullBody.__response__.parse_obj(resp.json())

Expand Down
1 change: 1 addition & 0 deletions src/dstack/api/server/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def _request(
logger.debug("POST /%s", path)
for _ in range(_MAX_RETRIES):
try:
# TODO: set adequate timeout here or everywhere the method is used
resp = self._s.post(f"{self._base_url}/{path}", **kwargs)
break
except requests.exceptions.ConnectionError as e:
Expand Down
Loading
0