[Core]: Support destroying all KV cache during runtime #10810

Open
wants to merge 1 commit into base: main

Conversation

@HollowMan6 (Contributor) commented on Dec 1, 2024

Implements #10714

API Design:

  • Destroy (implemented in this PR): `vllm.LLM().llm_engine._destroy_kv_caches()`
  • Re-initialize (already exists): `vllm.LLM().llm_engine._initialize_kv_caches()`
  • Stop loop (already exists): `vllm.LLM().llm_engine.model_executor.stop_remote_worker_execution_loop()`

This PR only implements `_destroy_kv_caches` for the GPU executor and workers, since I don't have other hardware available. Feel free to take over this PR to implement the remaining backends; once all the implementations are in place, we can make `destroy_cache()` an abstract method.
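
For reference, a rough sketch of what that eventual interface could look like; the class and method names below are assumptions for illustration, not code from this PR:

```python
# Hypothetical sketch only: once every backend implements KV cache destruction,
# the executor base class could declare destroy_cache() as an abstract method.
from abc import ABC, abstractmethod


class ExecutorBase(ABC):
    @abstractmethod
    def destroy_cache(self) -> None:
        """Free all KV cache blocks on every worker."""
        ...
```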

Also, since the engine cannot generate without KV caches (it will throw errors), this PR assumes that developers handle things on their side, so that no generation request is sent after `_destroy_kv_caches()` and before `_initialize_kv_caches()` (i.e., while in sleep mode).
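
As a minimal illustration of the intended cycle, here is a single-engine sketch (tensor_parallel_size=1) using only the methods listed above; it was not run as part of this PR, and the full Ray-based test follows below:

```python
# Minimal sketch: destroy the KV cache, then re-initialize it before generating again.
import vllm

llm = vllm.LLM("meta-llama/Llama-3.1-8B-Instruct")
print(llm.generate("San Francisco is a"))

# Enter "sleep mode": stop the worker execution loop, then free the KV cache.
llm.llm_engine.model_executor.stop_remote_worker_execution_loop()
llm.llm_engine._destroy_kv_caches()

# No generate() call is allowed in this window (the engine would throw).

# Wake up: re-allocate the KV cache, after which generation works again.
llm.llm_engine._initialize_kv_caches()
print(llm.generate("New York is a"))
```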

Code for testing:

```python
import ray, time
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy


@ray.remote
class LLMRayActor:
    def __init__(self, *args, **kwargs):
        import vllm

        # With tensor parallelism, let vLLM spawn its workers as Ray actors.
        if kwargs["tensor_parallel_size"] != 1:
            kwargs["worker_use_ray"] = True

        self.llm = vllm.LLM(*args, **kwargs)

    def generate(self, *args, **kwargs):
        return self.llm.generate(*args, **kwargs)

    def destroy_cache(self):
        self.stop_remote_worker_execution_loop()
        self.llm.llm_engine._destroy_kv_caches()

    def load_cache(self):
        self.stop_remote_worker_execution_loop()
        self.llm.llm_engine._initialize_kv_caches()

    def stop_remote_worker_execution_loop(self):
        self.llm.llm_engine.model_executor.stop_remote_worker_execution_loop()


def create_vllm_engines(
    num_engines: int,
    tensor_parallel_size: int,
    model: str,
):
    vllm_engines = []
    for _ in range(num_engines):
        # With TP > 1 the actor itself needs no GPU; the vLLM Ray workers claim them.
        num_gpus = int(tensor_parallel_size == 1)
        scheduling_strategy = None

        if tensor_parallel_size > 1:
            bundles = [{"GPU": 1, "CPU": 1}] * tensor_parallel_size
            pg = placement_group(bundles)
            ray.get(pg.ready())

            scheduling_strategy = PlacementGroupSchedulingStrategy(
                placement_group=pg,
                placement_group_capture_child_tasks=True,
                placement_group_bundle_index=0,
            )

        vllm_engines.append(
            LLMRayActor.options(
                num_cpus=1,
                num_gpus=num_gpus,
                scheduling_strategy=scheduling_strategy,
            ).remote(
                model,
                tensor_parallel_size=tensor_parallel_size,
            )
        )

    return vllm_engines


if __name__ == "__main__":
    # engines = create_vllm_engines(2, 2, "meta-llama/Llama-3.1-8B-Instruct")
    engines = create_vllm_engines(4, 1, "meta-llama/Llama-3.1-8B-Instruct")

    # Generate once with the KV cache in place.
    ref = []
    for engine in engines:
        ref.append(engine.generate.remote("San Francisco is a"))
    print(f"output: {ray.get(ref)}")

    # Free the KV cache on every engine.
    ref = []
    for engine in engines:
        ref.append(engine.destroy_cache.remote())
    ray.get(ref)

    time.sleep(5)

    # Re-allocate the KV cache.
    ref = []
    for engine in engines:
        ref.append(engine.load_cache.remote())
    ray.get(ref)

    # Generate again to confirm the engines still work after the cycle.
    ref = []
    for engine in engines:
        ref.append(engine.generate.remote("New York is a"))
    print(f"output: {ray.get(ref)}")
```

github-actions bot commented Dec 1, 2024

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors. You can run additional CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@HollowMan6 (Contributor, Author) commented:

> Also, since the engine cannot generate without KV caches (it will throw errors), this PR assumes that developers handle things on their side, so that no generation request is sent after `_destroy_kv_caches()` and before `_initialize_kv_caches()` (i.e., while in sleep mode).

Another possible way to handle this would be to check whether the KV cache has been initialized when the engine receives a request and, if not, initialize it, so that no manual intervention is needed. That is not implemented in this PR.
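
A rough sketch of that idea, purely hypothetical since it is not implemented in this PR; the `kv_caches_initialized` flag and its placement inside the engine's request path are assumptions for illustration:

```python
# Hypothetical fragment inside the engine (not part of this PR). Assumes a flag
# that _initialize_kv_caches() sets and _destroy_kv_caches() clears.
def add_request(self, request_id, prompt, params, **kwargs):
    if not self.kv_caches_initialized:
        # Lazily re-allocate the KV cache instead of raising an error.
        self._initialize_kv_caches()
    ...  # continue with normal request handling
```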
