diff --git a/Makefile b/Makefile index a8c5c631..a27707fa 100644 --- a/Makefile +++ b/Makefile @@ -5,6 +5,7 @@ tests: tests-basic: poetry run pytest tests/test_basic.py + poetry run pytest tests/test_api.py lint: poetry run ruff check docetl/* --fix diff --git a/README.md b/README.md index 8d603a54..6d3cbc94 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # DocETL: Powering Complex Document Processing Pipelines -[Website (Includes Demo)](https://docetl.com) | [Documentation](https://ucbepic.github.io/docetl) | [Discord](https://discord.gg/fHp7B2X3xx) | Paper (coming soon!) +[Website (Includes Demo)](https://docetl.com) | [Documentation](https://ucbepic.github.io/docetl) | [Discord](https://discord.gg/fHp7B2X3xx) | [NotebookLM Podcast](https://notebooklm.google.com/notebook/ef73248b-5a43-49cd-9976-432d20f9fa4f/audio?pli=1) (thanks Shabie from our Discord community!) | Paper (coming soon!) ![DocETL Figure](docs/assets/readmefig.png) @@ -65,4 +65,4 @@ make tests-basic That's it! You've successfully installed DocETL and are ready to start processing documents. -For more detailed information on usage and configuration, please refer to our [documentation](https://shreyashankar.github.io/docetl). +For more detailed information on usage and configuration, please refer to our [documentation](https://ucbepic.github.io/docetl). diff --git a/docetl/api.py b/docetl/api.py index ccde6f64..ede1b032 100644 --- a/docetl/api.py +++ b/docetl/api.py @@ -47,9 +47,13 @@ import os from typing import List, Optional, Dict, Any, Union +import yaml + from docetl.builder import Optimizer from docetl.runner import DSLRunner +from rich import print + @dataclass class Dataset: @@ -76,6 +80,7 @@ class MapOp(BaseOp): num_retries_on_validate_failure: Optional[int] = None gleaning: Optional[Dict[str, Any]] = None drop_keys: Optional[List[str]] = None + timeout: Optional[int] = None @dataclass @@ -94,6 +99,7 @@ class ResolveOp(BaseOp): compare_batch_size: Optional[int] = None limit_comparisons: Optional[int] = None optimize: Optional[bool] = None + timeout: Optional[int] = None @dataclass @@ -111,6 +117,7 @@ class ReduceOp(BaseOp): fold_batch_size: Optional[int] = None value_sampling: Optional[Dict[str, Any]] = None verbose: Optional[bool] = None + timeout: Optional[int] = None @dataclass @@ -122,6 +129,7 @@ class ParallelMapOp(BaseOp): recursively_optimize: Optional[bool] = None sample_size: Optional[int] = None drop_keys: Optional[List[str]] = None + timeout: Optional[int] = None @dataclass @@ -134,6 +142,7 @@ class FilterOp(BaseOp): sample_size: Optional[int] = None validate: Optional[List[str]] = None num_retries_on_validate_failure: Optional[int] = None + timeout: Optional[int] = None @dataclass @@ -152,6 +161,7 @@ class EquijoinOp(BaseOp): compare_batch_size: Optional[int] = None limit_comparisons: Optional[int] = None blocking_keys: Optional[Dict[str, List[str]]] = None + timeout: Optional[int] = None @dataclass @@ -274,6 +284,22 @@ def run(self, max_threads: Optional[int] = None) -> float: result = runner.run() return result + def to_yaml(self, path: str) -> None: + """ + Convert the Pipeline object to a YAML string and save it to a file. + + Args: + path (str): Path to save the YAML file. + + Returns: + None + """ + config = self._to_dict() + with open(path, "w") as f: + yaml.safe_dump(config, f) + + print(f"[green]Pipeline saved to {path}[/green]") + def _to_dict(self) -> Dict[str, Any]: """ Convert the Pipeline object to a dictionary representation. 
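Taken together, the `docetl/api.py` changes above add a per-operation `timeout` field and a `Pipeline.to_yaml()` helper. A minimal sketch of how these might be used from the Python API (the dataset path, prompt, and model are illustrative placeholders, not part of this change):

```python
# Sketch: per-operation timeout plus the new Pipeline.to_yaml() helper.
# "docs.json", the prompt, and "gpt-4o-mini" are illustrative placeholders.
from docetl.api import Pipeline, Dataset, MapOp, PipelineStep, PipelineOutput

pipeline = Pipeline(
    name="demo",
    datasets={"docs": Dataset(type="file", path="docs.json")},
    operations=[
        MapOp(
            name="summarize",
            type="map",
            prompt="Summarize the following document: {{ input.text }}",
            output={"schema": {"summary": "string"}},
            timeout=60,  # new field: per-LLM-call timeout in seconds
        )
    ],
    steps=[PipelineStep(name="summarize_step", input="docs", operations=["summarize"])],
    output=PipelineOutput(type="file", path="summaries.json"),
    default_model="gpt-4o-mini",
)

pipeline.to_yaml("pipeline.yaml")  # new helper: serialize the pipeline config to YAML
```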
diff --git a/docetl/operations/equijoin.py b/docetl/operations/equijoin.py index ab9f3d1b..f61347f0 100644 --- a/docetl/operations/equijoin.py +++ b/docetl/operations/equijoin.py @@ -15,7 +15,6 @@ from jinja2 import Template from litellm import embedding, model_cost from docetl.utils import completion_cost -from sklearn.metrics.pairwise import cosine_similarity from docetl.operations.base import BaseOperation from docetl.operations.utils import ( @@ -55,7 +54,12 @@ def process_left_item( def compare_pair( - comparison_prompt: str, model: str, item1: Dict, item2: Dict + comparison_prompt: str, + model: str, + item1: Dict, + item2: Dict, + timeout_seconds: int = 120, + max_retries_per_timeout: int = 2, ) -> Tuple[bool, float]: """ Compares two items using an LLM model to determine if they match. @@ -65,6 +69,8 @@ def compare_pair( model (str): The LLM model to use for comparison. item1 (Dict): The first item to compare. item2 (Dict): The second item to compare. + timeout_seconds (int): The timeout for the LLM call in seconds. + max_retries_per_timeout (int): The maximum number of retries per timeout. Returns: Tuple[bool, float]: A tuple containing a boolean indicating whether the items match and the cost of the comparison. @@ -77,6 +83,8 @@ def compare_pair( "compare", [{"role": "user", "content": prompt}], {"is_match": "bool"}, + timeout_seconds=timeout_seconds, + max_retries_per_timeout=max_retries_per_timeout, ) output = parse_llm_response(response)[0] return output["is_match"], completion_cost(response) @@ -279,6 +287,8 @@ def get_embeddings( ) # Compute all cosine similarities in one call + from sklearn.metrics.pairwise import cosine_similarity + similarities = cosine_similarity(left_embeddings, right_embeddings) # Additional blocking based on embeddings @@ -383,6 +393,8 @@ def get_embeddings( self.config.get("comparison_model", self.default_model), left, right, + self.config.get("timeout", 120), + self.config.get("max_retries_per_timeout", 2), ): (left, right) for left, right in blocked_pairs } diff --git a/docetl/operations/filter.py b/docetl/operations/filter.py index 98fa52c8..aee1df27 100644 --- a/docetl/operations/filter.py +++ b/docetl/operations/filter.py @@ -135,6 +135,10 @@ def validation_fn(response: Dict[str, Any]): messages, self.config["output"]["schema"], console=self.console, + timeout_seconds=self.config.get("timeout", 120), + max_retries_per_timeout=self.config.get( + "max_retries_per_timeout", 2 + ), ), validation_fn=validation_fn, val_rule=self.config.get("validate", []), diff --git a/docetl/operations/map.py b/docetl/operations/map.py index a4afd7f8..15953b20 100644 --- a/docetl/operations/map.py +++ b/docetl/operations/map.py @@ -154,6 +154,10 @@ def validation_fn(response: Dict[str, Any]): self.config["gleaning"]["validation_prompt"], self.config["gleaning"]["num_rounds"], self.console, + timeout_seconds=self.config.get("timeout", 120), + max_retries_per_timeout=self.config.get( + "max_retries_per_timeout", 2 + ), ), validation_fn=validation_fn, val_rule=self.config.get("validate", []), @@ -170,6 +174,10 @@ def validation_fn(response: Dict[str, Any]): self.config["output"]["schema"], tools=self.config.get("tools", None), console=self.console, + timeout_seconds=self.config.get("timeout", 120), + max_retries_per_timeout=self.config.get( + "max_retries_per_timeout", 2 + ), ), validation_fn=validation_fn, val_rule=self.config.get("validate", []), @@ -260,18 +268,12 @@ def syntax_check(self) -> None: if not isinstance(prompt_config, dict): raise TypeError(f"Prompt 
configuration {i} must be a dictionary") - required_keys = ["name", "prompt", "output_keys"] + required_keys = ["prompt", "output_keys"] for key in required_keys: if key not in prompt_config: raise ValueError( f"Missing required key '{key}' in prompt configuration {i}" ) - - if not isinstance(prompt_config["name"], str): - raise TypeError( - f"'name' in prompt configuration {i} must be a string" - ) - if not isinstance(prompt_config["prompt"], str): raise TypeError( f"'prompt' in prompt configuration {i} must be a string" @@ -362,6 +364,8 @@ def process_prompt(item, prompt_config): local_output_schema, tools=prompt_config.get("tools", None), console=self.console, + timeout_seconds=self.config.get("timeout", 120), + max_retries_per_timeout=self.config.get("max_retries_per_timeout", 2), ) output = parse_llm_response( response, tools=prompt_config.get("tools", None) diff --git a/docetl/operations/reduce.py b/docetl/operations/reduce.py index 56ca6b4c..df9d8b16 100644 --- a/docetl/operations/reduce.py +++ b/docetl/operations/reduce.py @@ -19,8 +19,6 @@ from jinja2 import Template from docetl.utils import completion_cost from litellm import embedding -from sklearn.cluster import KMeans -from sklearn.metrics.pairwise import cosine_similarity from docetl.operations.base import BaseOperation from docetl.operations.utils import ( @@ -60,6 +58,7 @@ def __init__(self, *args, **kwargs): if isinstance(self.config["reduce_key"], str) else self.config["reduce_key"] ) + self.intermediates = {} def syntax_check(self) -> None: """ @@ -351,6 +350,14 @@ def process_group( if output is not None: results.append(output) + if self.config.get("persist_intermediates", False): + for result in results: + key = tuple(result[k] for k in self.config["reduce_key"]) + if key in self.intermediates: + result[f"_{self.config['name']}_intermediates"] = ( + self.intermediates[key] + ) + return results, total_cost def _get_embeddings( @@ -383,6 +390,8 @@ def _cluster_based_sampling( ) -> Tuple[List[Dict], float]: embeddings, cost = self._get_embeddings(group_list, value_sampling) + from sklearn.cluster import KMeans + kmeans = KMeans(n_clusters=sample_size, random_state=42) cluster_labels = kmeans.fit_predict(embeddings) @@ -411,6 +420,8 @@ def _semantic_similarity_sampling( query_embedding = query_response["data"][0]["embedding"] cost += completion_cost(query_response) + from sklearn.metrics.pairwise import cosine_similarity + similarities = cosine_similarity([query_embedding], embeddings)[0] top_k_indices = np.argsort(similarities)[-sample_size:] @@ -464,6 +475,10 @@ def calculate_num_parallel_folds(): fold_results = [] remaining_items = group_list + if self.config.get("persist_intermediates", False): + self.intermediates[key] = [] + iter_count = 0 + # Parallel folding and merging with ThreadPoolExecutor(max_workers=self.max_threads) as executor: while remaining_items: @@ -485,6 +500,15 @@ def calculate_num_parallel_folds(): total_cost += cost if result is not None: new_fold_results.append(result) + if self.config.get("persist_intermediates", False): + self.intermediates[key].append( + { + "iter": iter_count, + "intermediate": result, + "scratchpad": result["updated_scratchpad"], + } + ) + iter_count += 1 # Update fold_results with new results fold_results = new_fold_results + fold_results[len(new_fold_results) :] @@ -507,6 +531,15 @@ def calculate_num_parallel_folds(): total_cost += cost if result is not None: new_results.append(result) + if self.config.get("persist_intermediates", False): + self.intermediates[key].append( + 
{ + "iter": iter_count, + "intermediate": result, + "scratchpad": None, + } + ) + iter_count += 1 fold_results = new_results @@ -538,6 +571,15 @@ def calculate_num_parallel_folds(): total_cost += cost if result is not None: new_results.append(result) + if self.config.get("persist_intermediates", False): + self.intermediates[key].append( + { + "iter": iter_count, + "intermediate": result, + "scratchpad": None, + } + ) + iter_count += 1 fold_results = new_results @@ -567,6 +609,10 @@ def _incremental_reduce( num_folds = (len(group_list) + fold_batch_size - 1) // fold_batch_size scratchpad = "" + if self.config.get("persist_intermediates", False): + self.intermediates[key] = [] + iter_count = 0 + for i in range(0, len(group_list), fold_batch_size): # Log the current iteration and total number of folds current_fold = i // fold_batch_size + 1 @@ -584,6 +630,16 @@ def _incremental_reduce( if folded_output is None: continue + if self.config.get("persist_intermediates", False): + self.intermediates[key].append( + { + "iter": iter_count, + "intermediate": folded_output, + "scratchpad": folded_output["updated_scratchpad"], + } + ) + iter_count += 1 + # Pop off updated_scratchpad if "updated_scratchpad" in folded_output: scratchpad = folded_output["updated_scratchpad"] @@ -635,6 +691,8 @@ def _increment_fold( self.config["output"]["schema"], scratchpad=scratchpad, console=self.console, + timeout_seconds=self.config.get("timeout", 120), + max_retries_per_timeout=self.config.get("max_retries_per_timeout", 2), ) folded_output = parse_llm_response(response)[0] @@ -674,6 +732,8 @@ def _merge_results( [{"role": "user", "content": merge_prompt}], self.config["output"]["schema"], console=self.console, + timeout_seconds=self.config.get("timeout", 120), + max_retries_per_timeout=self.config.get("max_retries_per_timeout", 2), ) merged_output = parse_llm_response(response)[0] merged_output.update(dict(zip(self.config["reduce_key"], key))) @@ -766,6 +826,8 @@ def _batch_reduce( self.config["gleaning"]["validation_prompt"], self.config["gleaning"]["num_rounds"], console=self.console, + timeout_seconds=self.config.get("timeout", 120), + max_retries_per_timeout=self.config.get("max_retries_per_timeout", 2), ) item_cost += gleaning_cost else: @@ -776,6 +838,8 @@ def _batch_reduce( self.config["output"]["schema"], console=self.console, scratchpad=scratchpad, + timeout_seconds=self.config.get("timeout", 120), + max_retries_per_timeout=self.config.get("max_retries_per_timeout", 2), ) item_cost += completion_cost(response) diff --git a/docetl/operations/resolve.py b/docetl/operations/resolve.py index 9903d271..5d0666af 100644 --- a/docetl/operations/resolve.py +++ b/docetl/operations/resolve.py @@ -12,7 +12,6 @@ from jinja2 import Template from docetl.utils import completion_cost from litellm import embedding -from sklearn.metrics.pairwise import cosine_similarity from docetl.operations.base import BaseOperation from docetl.operations.utils import ( @@ -23,6 +22,7 @@ validate_output, gen_embedding, ) +from rich.prompt import Confirm def compare_pair( @@ -31,6 +31,8 @@ def compare_pair( item1: Dict, item2: Dict, blocking_keys: List[str] = [], + timeout_seconds: int = 120, + max_retries_per_timeout: int = 2, ) -> Tuple[bool, float]: """ Compares two items using an LLM model to determine if they match. 
@@ -58,6 +60,8 @@ def compare_pair( "compare", [{"role": "user", "content": prompt}], {"is_match": "bool"}, + timeout_seconds=timeout_seconds, + max_retries_per_timeout=max_retries_per_timeout, ) output = parse_llm_response(response)[0] return output["is_match"], completion_cost(response) @@ -195,6 +199,22 @@ def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]: blocking_keys = self.config.get("blocking_keys", []) blocking_threshold = self.config.get("blocking_threshold") blocking_conditions = self.config.get("blocking_conditions", []) + + if not blocking_threshold and not blocking_conditions: + # Prompt the user for confirmation + if self.status: + self.status.stop() + if not Confirm.ask( + f"[yellow]Warning: No blocking keys or conditions specified. " + f"This may result in a large number of comparisons. " + f"We recommend specifying at least one blocking key or condition, or using the optimizer to automatically come up with these. " + f"Do you want to continue without blocking?[/yellow]", + ): + raise ValueError("Operation cancelled by user.") + + if self.status: + self.status.start() + input_schema = self.config.get("input", {}).get("schema", {}) if not blocking_keys: # Set them to all keys in the input data @@ -294,6 +314,8 @@ def meets_blocking_conditions(pair): ) if remaining_comparisons > 0 and blocking_threshold is not None: # Compute cosine similarity for all pairs efficiently + from sklearn.metrics.pairwise import cosine_similarity + similarity_matrix = cosine_similarity(embeddings) cosine_pairs = [] @@ -344,6 +366,10 @@ def meets_blocking_conditions(pair): input_data[pair[0]], input_data[pair[1]], blocking_keys, + timeout_seconds=self.config.get("timeout", 120), + max_retries_per_timeout=self.config.get( + "max_retries_per_timeout", 2 + ), ): pair for pair in batch } @@ -382,6 +408,10 @@ def process_cluster(cluster): [{"role": "user", "content": resolution_prompt}], self.config["output"]["schema"], console=self.console, + timeout_seconds=self.config.get("timeout", 120), + max_retries_per_timeout=self.config.get( + "max_retries_per_timeout", 2 + ), ) reduction_output = parse_llm_response(reduction_response)[0] reduction_cost = completion_cost(reduction_response) diff --git a/docetl/operations/utils.py b/docetl/operations/utils.py index f05046c0..9035f561 100644 --- a/docetl/operations/utils.py +++ b/docetl/operations/utils.py @@ -6,7 +6,6 @@ import threading from concurrent.futures import as_completed from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union -from openai import OpenAI from dotenv import load_dotenv from frozendict import frozendict @@ -30,8 +29,6 @@ LLM_CACHE_DIR = os.path.join(DOCETL_HOME_DIR, "llm_cache") cache = Cache(LLM_CACHE_DIR) -client = OpenAI() - def freezeargs(func): """ @@ -366,6 +363,8 @@ def call_llm( tools: Optional[List[Dict[str, str]]] = None, scratchpad: Optional[str] = None, console: Console = Console(), + timeout_seconds: int = 120, + max_retries_per_timeout: int = 2, ) -> Any: """ Wrapper function that uses caching for LLM calls. @@ -380,6 +379,8 @@ def call_llm( output_schema (Dict[str, str]): The output schema dictionary. tools (Optional[List[Dict[str, str]]]): The tools to pass to the LLM. scratchpad (Optional[str]): The scratchpad to use for the operation. + timeout_seconds (int): The timeout for the LLM call. + max_retries_per_timeout (int): The maximum number of retries per timeout. Returns: str: The result from the cached LLM call. 
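The next hunk rewires `call_llm`'s retry loop around these two parameters: the call is attempted `max_retries_per_timeout + 1` times, and each attempt is bounded by `timeout_seconds`. A generic, self-contained sketch of that behavior (an illustration of the semantics only, not docetl's actual implementation, which wraps `cached_call_llm` in its own timeout decorator):

```python
# Generic sketch of the retry-on-timeout semantics controlled by the new
# timeout_seconds / max_retries_per_timeout parameters. Not docetl's code.
from concurrent.futures import ThreadPoolExecutor, TimeoutError as FutureTimeout

def call_with_timeout(fn, *args, timeout_seconds=120, max_retries_per_timeout=2, **kwargs):
    # Attempt fn up to (max_retries_per_timeout + 1) times; each attempt is
    # abandoned once it exceeds timeout_seconds (the worker thread is not
    # killed, it is simply no longer waited on).
    pool = ThreadPoolExecutor(max_workers=max_retries_per_timeout + 1)
    try:
        for attempt in range(max_retries_per_timeout + 1):
            try:
                return pool.submit(fn, *args, **kwargs).result(timeout=timeout_seconds)
            except FutureTimeout:
                if attempt == max_retries_per_timeout:
                    raise  # out of retries: surface the timeout to the caller
    finally:
        pool.shutdown(wait=False)
```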
@@ -388,10 +389,10 @@ def call_llm( """ key = cache_key(model, op_type, messages, output_schema, scratchpad) - max_retries = 2 - for attempt in range(max_retries): + max_retries = max_retries_per_timeout + for attempt in range(max_retries + 1): try: - return timeout(120)(cached_call_llm)( + return timeout(timeout_seconds)(cached_call_llm)( key, model, op_type, @@ -610,6 +611,8 @@ def call_llm_with_gleaning( validator_prompt_template: str, num_gleaning_rounds: int, console: Console = Console(), + timeout_seconds: int = 120, + max_retries_per_timeout: int = 2, ) -> Tuple[str, float]: """ Call LLM with a gleaning process, including validation and improvement rounds. @@ -624,7 +627,7 @@ def call_llm_with_gleaning( output_schema (Dict[str, str]): The output schema dictionary. validator_prompt_template (str): Template for the validator prompt. num_gleaning_rounds (int): Number of gleaning rounds to perform. - + timeout_seconds (int): The timeout for the LLM call. Returns: Tuple[str, float]: A tuple containing the final LLM response and the total cost. """ @@ -635,7 +638,15 @@ def call_llm_with_gleaning( parameters["additionalProperties"] = False # Initial LLM call - response = call_llm(model, op_type, messages, output_schema, console=console) + response = call_llm( + model, + op_type, + messages, + output_schema, + console=console, + timeout_seconds=timeout_seconds, + max_retries_per_timeout=max_retries_per_timeout, + ) cost = 0.0 @@ -792,16 +803,16 @@ def parse_llm_response( if "tool_calls" in dir(response.choices[0].message): # Default behavior for write_output function tool_calls = response.choices[0].message.tool_calls + if not tool_calls: raise ValueError("No tool calls found in response") outputs = [] for tool_call in tool_calls: - if tool_call.function.name == "write_output": - try: - outputs.append(json.loads(tool_call.function.arguments)) - except json.JSONDecodeError: - return [{}] + try: + outputs.append(json.loads(tool_call.function.arguments)) + except json.JSONDecodeError: + return [{}] return outputs else: diff --git a/docetl/runner.py b/docetl/runner.py index 2c6ceb73..dc783f1c 100644 --- a/docetl/runner.py +++ b/docetl/runner.py @@ -229,7 +229,11 @@ def execute_step( operation_class = get_operation(op_object["type"]) operation_instance = operation_class( - op_object, self.default_model, self.max_threads, self.console + op_object, + self.default_model, + self.max_threads, + self.console, + self.status, ) if op_object["type"] == "equijoin": left_data = self.datasets[op_object["left"]] diff --git a/docs/examples/ollama.md b/docs/examples/ollama.md new file mode 100644 index 00000000..40f336c8 --- /dev/null +++ b/docs/examples/ollama.md @@ -0,0 +1,140 @@ +# Medical Document Classification with Ollama + +This tutorial demonstrates how to use DocETL with [Ollama](https://github.com/ollama/ollama) models to classify medical documents into predefined categories. We'll use a simple map operation to process a set of medical records, ensuring that sensitive information remains private by using a locally-run model. + +## Setup + +!!! note "Prerequisites" + + Before we begin, make sure you have Ollama installed and running on your local machine. + +You'll need to set the OLLAMA_API_BASE environment variable: + +```bash +export OLLAMA_API_BASE=http://localhost:11434/ +``` + +!!! info "API Details" + + For more information on the Ollama REST API, refer to the [Ollama documentation](https://github.com/ollama/ollama?tab=readme-ov-file#rest-api). 
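Before building the pipeline, you can optionally confirm the Ollama server is reachable. A minimal check, assuming the default local endpoint:

```python
# Optional sanity check: the Ollama server answers GET / with "Ollama is running".
import os
import urllib.request

base = os.environ.get("OLLAMA_API_BASE", "http://localhost:11434/")
with urllib.request.urlopen(base, timeout=5) as resp:
    print(resp.status, resp.read().decode())  # expect: 200 Ollama is running
```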
+ +## Pipeline Configuration + +Let's create a pipeline that classifies medical documents into categories such as "Cardiology", "Neurology", "Oncology", etc. + +!!! example "Initial Pipeline Configuration" + + ```yaml + datasets: + medical_records: + type: file + path: "medical_records.json" + + default_model: ollama/llama3 + + operations: + - name: classify_medical_record + type: map + output: + schema: + categories: "list[str]" + prompt: | + Classify the following medical record into one or more of these categories: Cardiology, Neurology, Oncology, Pediatrics, Orthopedics. + + Medical Record: + {{ input.text }} + + Return your answer as a JSON list of strings, e.g., ["Cardiology", "Neurology"]. + + pipeline: + steps: + - name: medical_classification + input: medical_records + operations: + - classify_medical_record + + output: + type: file + path: "classified_records.json" + ``` + +## Running the Pipeline with a Sample + +To test our pipeline and estimate the required timeout, we'll first run it on a sample of documents. + +Modify the `classify_medical_record` operation in your configuration to include a `sample` parameter: + +```yaml +operations: + - name: classify_medical_record + type: map + sample: 5 + output: + schema: + categories: "list[str]" + prompt: | + Classify the following medical record into one or more of these categories: Cardiology, Neurology, Oncology, Pediatrics, Orthopedics. + + Medical Record: + {{ input.text }} + + Return your answer as a JSON list of strings, e.g., ["Cardiology", "Neurology"]. +``` + +Now, run the pipeline with this sample configuration: + +```bash +docetl run pipeline.yaml +``` + +## Adjusting the Timeout + +After running the sample, note the time it took to process 5 documents. + +!!! example "Timeout Calculation" + + Let's say it took 100 seconds to process 5 documents. You can use this to estimate the time needed for your full dataset. For example, if you have 1000 documents in total, you might want to set the timeout to: + + (100 seconds / 5 documents) * 1000 documents = 20,000 seconds + +Now, adjust your pipeline configuration to include this timeout and remove the sample parameter: + +```yaml +operations: + - name: classify_medical_record + type: map + timeout: 20000 + output: + schema: + categories: "list[str]" + prompt: | + Classify the following medical record into one or more of these categories: Cardiology, Neurology, Oncology, Pediatrics, Orthopedics. + Medical Record: + {{ input.text }} + Return your answer as a JSON list of strings, e.g., ["Cardiology", "Neurology"]. +``` + +!!! note "Caching" + + DocETL caches results (even between runs), so if the same document is processed again, the answer will be returned from the cache rather than processed again (significantly speeding up processing). + +## Running the Full Pipeline + +Now you can run the full pipeline with the adjusted timeout: + +```bash +docetl run pipeline.yaml +``` + +This will process all your medical records, classifying them into the predefined categories. + +## Conclusion + +!!! success "Key Takeaways" + + - This pipeline demonstrates how to use Ollama with DocETL for local processing of sensitive data. + - Ollama integrates into multi-operation pipelines, maintaining data privacy. + - Ollama is a local model, so it is much slower than leveraging an LLM API like OpenAI. Adjust the timeout accordingly. + - DocETL's sample and timeout parameters help optimize the pipeline for efficient use of Ollama's capabilities. 
+ +For more information, e.g., for specific models, visit [https://ollama.com/](https://ollama.com/). diff --git a/docs/operators/filter.md b/docs/operators/filter.md index aaa83c2a..c3d51c31 100644 --- a/docs/operators/filter.md +++ b/docs/operators/filter.md @@ -89,6 +89,8 @@ This example demonstrates how the Filter operation distinguishes between high-im | `sample_size` | Number of samples to use for the operation | Processes all data | | `validate` | List of Python expressions to validate the output | None | | `num_retries_on_validate_failure` | Number of retry attempts on validation failure | 0 | +| `timeout` | Timeout for each LLM call in seconds | 120 | +| `max_retries_per_timeout` | Maximum number of retries per timeout | 2 | !!! info "Validation" diff --git a/docs/operators/map.md b/docs/operators/map.md index 90afcc4f..8b536f56 100644 --- a/docs/operators/map.md +++ b/docs/operators/map.md @@ -142,9 +142,13 @@ This example demonstrates how the Map operation can transform long, unstructured | `num_retries_on_validate_failure` | Number of retry attempts on validation failure | 0 | | `gleaning` | Configuration for advanced validation and LLM-based refinement | None | | `drop_keys` | List of keys to drop from the input before processing | None | +| `timeout` | Timeout for each LLM call in seconds | 120 | +| `max_retries_per_timeout` | Maximum number of retries per timeout | 2 | Note: If `drop_keys` is specified, `prompt` and `output` become optional parameters. +| `timeout` | Timeout for each LLM call in seconds | 120 | + !!! info "Validation and Gleaning" For more details on validation techniques and implementation, see [operators](../concepts/operators.md#validation). diff --git a/docs/operators/parallel-map.md b/docs/operators/parallel-map.md index 3c6d713f..a4e0f6f1 100644 --- a/docs/operators/parallel-map.md +++ b/docs/operators/parallel-map.md @@ -23,26 +23,28 @@ The output schema should include all the fields generated by the individual prom Each prompt configuration in the `prompts` list should contain: -- `name`: A unique name for the prompt - `prompt`: The prompt template to use for the transformation - `output_keys`: List of keys that this prompt will generate - `model` (optional): The language model to use for this specific prompt ### Optional Parameters -| Parameter | Description | Default | -| ---------------------- | ------------------------------------------ | ----------------------------- | -| `model` | The default language model to use | Falls back to `default_model` | -| `optimize` | Flag to enable operation optimization | True | -| `recursively_optimize` | Flag to enable recursive optimization | false | -| `sample_size` | Number of samples to use for the operation | Processes all data | +| Parameter | Description | Default | +| ------------------------- | ------------------------------------------ | ----------------------------- | +| `model` | The default language model to use | Falls back to `default_model` | +| `optimize` | Flag to enable operation optimization | True | +| `recursively_optimize` | Flag to enable recursive optimization | false | +| `sample_size` | Number of samples to use for the operation | Processes all data | +| `timeout` | Timeout for each LLM call in seconds | 120 | +| `max_retries_per_timeout` | Maximum number of retries per timeout | 2 | ??? question "Why use Parallel Map instead of multiple Map operations?" -While you could achieve similar results with multiple Map operations, Parallel Map offers several advantages: - 1. 
**Concurrency**: Prompts run in parallel, potentially reducing overall processing time. - 2. **Simplified Configuration**: You define multiple transformations in a single operation, reducing pipeline complexity. - 3. **Unified Output**: Results from all prompts are combined into a single output item, simplifying downstream operations. + While you could achieve similar results with multiple Map operations, Parallel Map offers several advantages: + + 1. **Concurrency**: Prompts run in parallel, potentially reducing overall processing time. + 2. **Simplified Configuration**: You define multiple transformations in a single operation, reducing pipeline complexity. + 3. **Unified Output**: Results from all prompts are combined into a single output item, simplifying downstream operations. ## 🚀 Example: Processing Job Applications diff --git a/docs/operators/reduce.md b/docs/operators/reduce.md index 331deb34..018f1cde 100644 --- a/docs/operators/reduce.md +++ b/docs/operators/reduce.md @@ -49,17 +49,20 @@ This Reduce operation processes customer feedback grouped by department: ### Optional Parameters -| Parameter | Description | Default | -| -------------------- | ------------------------------------------------------------------------------- | --------------------------- | -| `synthesize_resolve` | If false, won't synthesize a resolve operation between map and reduce | true | -| `model` | The language model to use | Falls back to default_model | -| `input` | Specifies the schema or keys to subselect from each item | All keys from input items | -| `pass_through` | If true, non-input keys from the first item in the group will be passed through | false | -| `associative` | If true, the reduce operation is associative (i.e., order doesn't matter) | true | -| `fold_prompt` | A prompt template for incremental folding | None | -| `fold_batch_size` | Number of items to process in each fold operation | None | -| `value_sampling` | A dictionary specifying the sampling strategy for large groups | None | -| `verbose` | If true, enables detailed logging of the reduce operation | false | +| Parameter | Description | Default | +| ------------------------- | ------------------------------------------------------------------------------------------------------ | --------------------------- | +| `synthesize_resolve` | If false, won't synthesize a resolve operation between map and reduce | true | +| `model` | The language model to use | Falls back to default_model | +| `input` | Specifies the schema or keys to subselect from each item | All keys from input items | +| `pass_through` | If true, non-input keys from the first item in the group will be passed through | false | +| `associative` | If true, the reduce operation is associative (i.e., order doesn't matter) | true | +| `fold_prompt` | A prompt template for incremental folding | None | +| `fold_batch_size` | Number of items to process in each fold operation | None | +| `value_sampling` | A dictionary specifying the sampling strategy for large groups | None | +| `verbose` | If true, enables detailed logging of the reduce operation | false | +| `persist_intermediates` | If true, persists the intermediate results for each group to the key `_{operation_name}_intermediates` | false | +| `timeout` | Timeout for each LLM call in seconds | 120 | +| `max_retries_per_timeout` | Maximum number of retries per timeout | 2 | ## Advanced Features diff --git a/docs/operators/resolve.md b/docs/operators/resolve.md index c4447ff3..722e8e31 100644 --- 
a/docs/operators/resolve.md +++ b/docs/operators/resolve.md @@ -107,18 +107,20 @@ After determining eligible pairs for comparison, the Resolve operation uses a Un ## Optional Parameters -| Parameter | Description | Default | -| ---------------------- | --------------------------------------------------------------------------------- | ----------------------------- | -| `embedding_model` | The model to use for creating embeddings | Falls back to `default_model` | -| `resolution_model` | The language model to use for reducing matched entries | Falls back to `default_model` | -| `comparison_model` | The language model to use for comparing potential matches | Falls back to `default_model` | -| `blocking_keys` | List of keys to use for initial blocking | All keys in the input data | -| `blocking_threshold` | Embedding similarity threshold for considering entries as potential matches | None | -| `blocking_conditions` | List of conditions for initial blocking | [] | -| `input` | Specifies the schema or keys to subselect from each item to pass into the prompts | All keys from input items | -| `embedding_batch_size` | The number of entries to send to the embedding model at a time | 1000 | -| `compare_batch_size` | The number of entity pairs processed in each batch during the comparison phase | 100 | -| `limit_comparisons` | Maximum number of comparisons to perform | None | +| Parameter | Description | Default | +| ------------------------- | --------------------------------------------------------------------------------- | ----------------------------- | +| `embedding_model` | The model to use for creating embeddings | Falls back to `default_model` | +| `resolution_model` | The language model to use for reducing matched entries | Falls back to `default_model` | +| `comparison_model` | The language model to use for comparing potential matches | Falls back to `default_model` | +| `blocking_keys` | List of keys to use for initial blocking | All keys in the input data | +| `blocking_threshold` | Embedding similarity threshold for considering entries as potential matches | None | +| `blocking_conditions` | List of conditions for initial blocking | [] | +| `input` | Specifies the schema or keys to subselect from each item to pass into the prompts | All keys from input items | +| `embedding_batch_size` | The number of entries to send to the embedding model at a time | 1000 | +| `compare_batch_size` | The number of entity pairs processed in each batch during the comparison phase | 100 | +| `limit_comparisons` | Maximum number of comparisons to perform | None | +| `timeout` | Timeout for each LLM call in seconds | 120 | +| `max_retries_per_timeout` | Maximum number of retries per timeout | 2 | ## Best Practices diff --git a/mkdocs.yml b/mkdocs.yml index 34f79b8c..41c60bff 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -45,17 +45,18 @@ nav: # - User-Defined Functions: advanced/custom-operators.md # - Extending Optimizer Agents: advanced/extending-agents.md # - Performance Tuning: advanced/performance-tuning.md + - Examples: + - Reporting on Themes from Presidential Debates: examples/presidential-debate-themes.md + - Mining Product Reviews for Polarizing Features: examples/mining-product-reviews.md + - Medical Document Classification with Ollama: examples/ollama.md + # - Annotating Legal Documents: examples/annotating-legal-documents.md + # - Characterizing Troll Behavior on Wikipedia: examples/characterizing-troll-behavior.md - API Reference: - docetl: api-reference/docetl.md - docetl.cli: api-reference/cli.md 
- docetl.operations: api-reference/operations.md - docetl.optimizers: api-reference/optimizers.md - Python API: api-reference/python.md - - Examples: - - Reporting on Themes from Presidential Debates: examples/presidential-debate-themes.md - - Mining Product Reviews for Polarizing Features: examples/mining-product-reviews.md - # - Annotating Legal Documents: examples/annotating-legal-documents.md - # - Characterizing Troll Behavior on Wikipedia: examples/characterizing-troll-behavior.md - Community: - Community: community/index.md - Roadmap: community/roadmap.md diff --git a/pyproject.toml b/pyproject.toml index 66025cc7..fe55a4c8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "docetl" -version = "0.1.2" +version = "0.1.3" description = "ETL with LLM operations." authors = ["Shreya Shankar "] license = "MIT" diff --git a/tests/test_map.py b/tests/test_map.py index a5346a04..ccf390a4 100644 --- a/tests/test_map.py +++ b/tests/test_map.py @@ -55,3 +55,50 @@ def test_map_operation_with_word_count_tool(map_config_with_tools, synthetic_dat assert all("word_count" in result for result in results) assert [result["word_count"] for result in results] == [5, 6, 5, 1] assert cost > 0 # Ensure there was some cost associated with the operation + + +@pytest.fixture +def simple_map_config(): + return { + "name": "simple_sentiment_analysis", + "type": "map", + "prompt": "Analyze the sentiment of the following text: '{{ input.text }}'. Classify it as either positive, negative, or neutral.", + "output": {"schema": {"sentiment": "string"}}, + "model": "gpt-4o-mini", + } + + +@pytest.fixture +def simple_sample_data(): + import random + import string + + def generate_random_text(length): + return "".join( + random.choice( + string.ascii_letters + string.digits + string.punctuation + " " + ) + for _ in range(length) + ) + + return [ + {"text": generate_random_text(random.randint(20, 100000))}, + {"text": generate_random_text(random.randint(20, 100000))}, + {"text": generate_random_text(random.randint(20, 100000))}, + ] + + +def test_map_operation_with_timeout(simple_map_config, simple_sample_data): + # Add timeout to the map configuration + map_config_with_timeout = { + **simple_map_config, + "timeout": 1, + "max_retries_per_timeout": 0, + } + + operation = MapOperation(map_config_with_timeout, "gpt-4o-mini", 4) + + # Execute the operation and expect empty results + results, cost = operation.execute(simple_sample_data) + for result in results: + assert "sentiment" not in result diff --git a/tests/test_ollama.py b/tests/test_ollama.py new file mode 100644 index 00000000..81d8fbfb --- /dev/null +++ b/tests/test_ollama.py @@ -0,0 +1,113 @@ +import shutil +import pytest +import json +import tempfile +import os +from docetl.api import ( + Pipeline, + Dataset, + MapOp, + ReduceOp, + PipelineStep, + PipelineOutput, +) +from dotenv import load_dotenv + +load_dotenv() + +# Set the OLLAMA_API_BASE environment variable +os.environ["OLLAMA_API_BASE"] = "http://localhost:11434/" + + +@pytest.fixture +def temp_input_file(): + with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".json") as tmp: + json.dump( + [ + {"text": "This is a test", "group": "A"}, + {"text": "Another test", "group": "B"}, + ], + tmp, + ) + yield tmp.name + os.unlink(tmp.name) + + +@pytest.fixture +def temp_output_file(): + with tempfile.NamedTemporaryFile(delete=False, suffix=".json") as tmp: + pass + yield tmp.name + os.unlink(tmp.name) + + +@pytest.fixture +def temp_intermediate_dir(): + with 
tempfile.TemporaryDirectory() as tmpdirname: + yield tmpdirname + + +@pytest.fixture +def map_config(): + return MapOp( + name="sentiment_analysis", + type="map", + prompt="Analyze the sentiment of the following text: '{{ input.text }}'. Classify it as either positive, negative, or neutral.", + output={"schema": {"sentiment": "string"}}, + model="ollama/llama3", + ) + + +@pytest.fixture +def reduce_config(): + return ReduceOp( + name="group_summary", + type="reduce", + reduce_key="group", + prompt="Summarize the following group of values: {{ inputs }} Provide a total and any other relevant statistics.", + output={"schema": {"total": "number", "avg": "number"}}, + model="ollama/llama3", + ) + + +@pytest.fixture(autouse=True) +def remove_openai_api_key(): + openai_api_key = os.environ.pop("OPENAI_API_KEY", None) + yield + if openai_api_key: + os.environ["OPENAI_API_KEY"] = openai_api_key + + +def test_ollama_map_reduce_pipeline( + map_config, reduce_config, temp_input_file, temp_output_file, temp_intermediate_dir +): + pipeline = Pipeline( + name="test_ollama_pipeline", + datasets={"test_input": Dataset(type="file", path=temp_input_file)}, + operations=[map_config, reduce_config], + steps=[ + PipelineStep( + name="pipeline", + input="test_input", + operations=["sentiment_analysis", "group_summary"], + ), + ], + output=PipelineOutput( + type="file", path=temp_output_file, intermediate_dir=temp_intermediate_dir + ), + default_model="ollama/llama3", + ) + + cost = pipeline.run() + + assert isinstance(cost, float) + assert cost == 0 + + # Verify output file exists and contains data + assert os.path.exists(temp_output_file) + with open(temp_output_file, "r") as f: + output_data = json.load(f) + assert len(output_data) > 0 + + # Clean up + shutil.rmtree(temp_intermediate_dir) diff --git a/tests/test_reduce.py b/tests/test_reduce.py index bb528a5c..c8fdf65b 100644 --- a/tests/test_reduce.py +++ b/tests/test_reduce.py @@ -194,3 +194,71 @@ def test_reduce_operation_non_associative(default_model, max_threads): assert combined_result.index("brave princess") < combined_result.index( "dragon" ), "Princess should be mentioned before the dragon in the story" + + +def test_reduce_operation_persist_intermediates(default_model, max_threads): + # Define a config with persist_intermediates enabled + persist_intermediates_config = { + "name": "persist_intermediates_reduce", + "type": "reduce", + "reduce_key": "group", + "persist_intermediates": True, + "prompt": "Summarize the numbers in '{{ inputs }}'.", + "fold_prompt": "Combine summaries: Previous '{{ output }}', New '{{ inputs[0] }}'.", + "fold_batch_size": 2, + "output": {"schema": {"summary": "string"}}, + } + + # Sample data with more items than fold_batch_size + sample_data = [ + {"group": "numbers", "value": "1, 2"}, + {"group": "numbers", "value": "3, 4"}, + {"group": "numbers", "value": "5, 6"}, + {"group": "numbers", "value": "7, 8"}, + {"group": "numbers", "value": "9, 10"}, + ] + + operation = ReduceOperation( + persist_intermediates_config, default_model, max_threads + ) + results, cost = operation.execute(sample_data) + + assert len(results) == 1, "Should have one result for the 'numbers' group" + assert cost > 0, "Cost should be greater than 0" + + result = results[0] + assert "summary" in result, "Result should have a 'summary' key" + + # Check if intermediates were persisted + assert ( + "_persist_intermediates_reduce_intermediates" in result + ), "Result should have '_persist_intermediates_reduce_intermediates' key" + intermediates = 
result["_persist_intermediates_reduce_intermediates"] + assert isinstance(intermediates, list), "Intermediates should be a list" + assert len(intermediates) > 1, "Should have multiple intermediate results" + + # Check the structure of intermediates + for intermediate in intermediates: + assert "iter" in intermediate, "Each intermediate should have an 'iter' key" + assert ( + "intermediate" in intermediate + ), "Each intermediate should have an 'intermediate' key" + assert ( + "scratchpad" in intermediate + ), "Each intermediate should have a 'scratchpad' key" + + # Verify that the intermediate results are stored in the correct order + for i in range(1, len(intermediates)): + assert ( + intermediates[i]["iter"] > intermediates[i - 1]["iter"] + ), "Intermediate results should be in ascending order of iterations" + + # Check if the intermediate results are accessible via the special key + for result in results: + assert ( + f"_persist_intermediates_reduce_intermediates" in result + ), "Result should contain the special intermediate key" + stored_intermediates = result[f"_persist_intermediates_reduce_intermediates"] + assert ( + stored_intermediates == intermediates + ), "Stored intermediates should match the operation's intermediates"