Add streaming support for zero shot inference by arnavgarg1 · Pull Request #3878 · ludwig-ai/ludwig · GitHub

Add streaming support for zero shot inference #3878


Merged · merged 5 commits on Jan 11, 2024
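
The snippet below is a minimal usage sketch (not part of the diff) of how the new streaming flag could be exercised from the Ludwig Python API; the model directory and prompt are hypothetical placeholders, and the generation_config keys follow the standard Hugging Face generate() options.

from ludwig.api import LudwigModel

# Load a previously trained LLM model (hypothetical path).
model = LudwigModel.load("results/api_experiment_run/model")

# With streaming=True, tokens are written to stdout as they are generated,
# via the TextStreamer created in ludwig.utils.llm_utils.
response = model.generate(
    "What is the capital of France?",
    generation_config={"max_new_tokens": 64},
    streaming=True,
)
print(response)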
Changes from all commits
106 changes: 87 additions & 19 deletions ludwig/api.py
@@ -96,6 +96,7 @@
from ludwig.utils.defaults import default_random_seed
from ludwig.utils.fs_utils import makedirs, path_exists, upload_output_directory
from ludwig.utils.heuristics import get_auto_learning_rate
from ludwig.utils.llm_utils import create_text_streamer, TextStreamer
from ludwig.utils.misc_utils import (
get_commit_hash,
get_file_names,
@@ -966,8 +967,18 @@ def generate(
self,
input_strings: Union[str, List[str]],
generation_config: Optional[dict] = None,
streaming: Optional[bool] = False,
) -> Union[str, List[str]]:
"""A simple generate() method that directly uses the underlying transformers library to generate text."""
"""A simple generate() method that directly uses the underlying transformers library to generate text.

Args:
input_strings (Union[str, List[str]]): Input text or list of texts to generate from.
generation_config (Optional[dict]): Configuration for text generation.
streaming (Optional[bool]): If True, stream generated tokens to stdout as they are produced.

Returns:
Union[str, List[str]]: Generated text or list of generated texts.
"""
if self.config_obj.model_type != MODEL_LLM:
raise ValueError(
f"Model type {self.config_obj.model_type} is not supported by this method. Only `llm` model type is "
@@ -983,31 +994,88 @@
# warning, e.g.:
# "A decoder-only architecture is being used, but right-padding was detected! For correct generation results, "
# "please set `padding_side='left'` when initializing the tokenizer.
padding_side = "left" if not self.model.model.config.is_encoder_decoder else "right"
tokenizer = HFTokenizer(self.config_obj.base_model, padding_side=padding_side)

with self.model.use_generation_config(generation_config):
start_time = time.time()
tokenized_inputs = tokenizer.tokenizer(input_strings, return_tensors="pt", padding=True)
input_ids = tokenized_inputs["input_ids"].to("cuda")
attention_mask = tokenized_inputs["attention_mask"].to("cuda")

if streaming:
streamer = create_text_streamer(tokenizer.tokenizer)
outputs = self._generate_streaming_outputs(input_strings, input_ids, attention_mask, streamer)
else:
outputs = self._generate_non_streaming_outputs(input_strings, input_ids, attention_mask)

decoded_outputs = tokenizer.tokenizer.batch_decode(outputs, skip_special_tokens=True)
logger.info(f"Finished generating in: {(time.time() - start_time):.2f}s.")

return decoded_outputs[0] if len(decoded_outputs) == 1 else decoded_outputs

def _generate_streaming_outputs(
self,
input_strings: Union[str, List[str]],
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
streamer: TextStreamer,
) -> torch.Tensor:
"""Generate streaming outputs for the given input.

Args:
input_strings (Union[str, List[str]]): Input text or list of texts to generate from.
input_ids (torch.Tensor): Tensor containing input IDs.
attention_mask (torch.Tensor): Tensor containing attention masks.
streamer (TextStreamer): Text streamer instance used to stream generated tokens to stdout.

Returns:
torch.Tensor: Concatenated tensor of generated outputs.
"""
outputs = []
input_strings = input_strings if isinstance(input_strings, list) else [input_strings]
for i in range(len(input_ids)):
    with torch.no_grad():
        logger.info(f"Input: {input_strings[i]}\n")
        # NOTE: self.model.model.generation_config is not used here because it is the default
        # generation config that the CausalLM was initialized with, rather than the one set within the
        # context manager.
        generated_output = self.model.model.generate(
            input_ids=input_ids[i].unsqueeze(0),
            attention_mask=attention_mask[i].unsqueeze(0),
            generation_config=self.model.generation,
            streamer=streamer,
        )
        logger.info("----------------------")
        outputs.append(generated_output)
return torch.cat(outputs, dim=0)
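
Note that the streaming path generates one example at a time, logging each prompt before streaming its continuation, presumably because transformers' TextStreamer only supports a batch size of 1. The per-example outputs are then concatenated back into a single tensor so that the decoding step in generate() is identical for the streaming and non-streaming paths.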

def _generate_non_streaming_outputs(
self,
_input_strings: Union[str, List[str]],
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
) -> torch.Tensor:
"""Generate non-streaming outputs for the given input.

Args:
_input_strings (Union[str, List[str]]): Unused input parameter.
input_ids (torch.Tensor): Tensor containing input IDs.
attention_mask (torch.Tensor): Tensor containing attention masks.

Returns:
torch.Tensor: Tensor of generated outputs.
"""
with torch.no_grad():
# NOTE: self.model.model.generation_config is not used here because it is the default
# generation config that the CausalLM was initialized with, rather than the one set within the
# context manager.
return self.model.model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
generation_config=self.model.generation,
)

def predict(
self,
7 changes: 6 additions & 1 deletion ludwig/utils/llm_utils.py
@@ -8,7 +8,7 @@
import transformers
from bitsandbytes.nn.modules import Embedding
from packaging import version
from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizer, TextStreamer

from ludwig.constants import IGNORE_INDEX_TOKEN_ID, LOGITS, PREDICTIONS, PROBABILITIES
from ludwig.schema.trainer import LLMTrainerConfig
@@ -622,3 +622,8 @@ def update_embedding_layer(model: AutoModelForCausalLM, config_obj: LLMTrainerConfig
logger.info("Updated the pretrained embedding layer to use the embedding layer from bitsandbytes.")

return model


def create_text_streamer(tokenizer: PreTrainedTokenizer) -> TextStreamer:
"""Creates a TextStreamer object for streaming text to stdout during generation."""
return TextStreamer(tokenizer=tokenizer, skip_prompt=True)
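
For reference, here is a minimal standalone sketch of how the returned streamer behaves with the transformers generate() API; the "gpt2" checkpoint and prompt are illustrative assumptions, not part of this change.

from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative model choice
model = AutoModelForCausalLM.from_pretrained("gpt2")

# skip_prompt=True means the echoed input prompt is not printed, only newly generated tokens.
streamer = TextStreamer(tokenizer=tokenizer, skip_prompt=True)

inputs = tokenizer("Once upon a time", return_tensors="pt")
# Decoded tokens are written to stdout incrementally as generate() produces them.
model.generate(**inputs, max_new_tokens=20, streamer=streamer)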