Bugfix/aws refacto by MartBakler · Pull Request #125 · Tanuki/tanuki.py · GitHub

Bugfix/aws refacto #125


Merged · 7 commits · Jan 17, 2024
8 changes: 5 additions & 3 deletions src/tanuki/function_modeler.py
@@ -2,7 +2,7 @@
import datetime
import io
import json
from typing import List, Tuple, Dict
from typing import List, Tuple, Dict, Union

import openai

@@ -50,7 +50,7 @@ def _get_dataset_info(self, dataset_type, func_hash, type="length"):
return self.data_worker.load_dataset(dataset_type, func_hash, return_type=type)

def _configure_teacher_models(self,
teacher_models: list[str, BaseModelConfig],
teacher_models: List[Union[str, BaseModelConfig]],
func_hash: str,
task_type: str):
"""
@@ -77,6 +77,8 @@ def _configure_teacher_models(self,
# currently ban all non-openai models from finetuning because it doesnt make sense
if model_config.provider != OPENAI_PROVIDER and func_hash not in self.check_finetune_blacklist:
self.check_finetune_blacklist.append(func_hash)
if model_config.provider != OPENAI_PROVIDER and func_hash not in self.execute_finetune_blacklist:
self.execute_finetune_blacklist.append(func_hash)

def _get_datasets(self):
"""
@@ -552,7 +554,7 @@ def _check_finetuning_status(self, func_hash):
last_checked = self.function_configs[func_hash].current_training_run["last_checked"]
# check if last checked was more than 30 mins ago
if (datetime.datetime.now() - datetime.datetime.strptime(last_checked,
"%Y-%m-%d %H:%M:%S")).total_seconds() > 1:
"%Y-%m-%d %H:%M:%S")).total_seconds() > 1800:
finetune_provider = self.function_configs[func_hash].distilled_model.provider
response = self.api_provider[finetune_provider].get_finetuned(job_id)
self.function_configs[func_hash].current_training_run["last_checked"] = datetime.datetime.now().strftime(
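
With a threshold of > 1 second, the job status was re-polled on virtually every call; 1800 seconds restores the 30-minute interval the comment above describes. The corrected check in isolation (function and constant names here are hypothetical; the timestamp format is copied from the diff):

import datetime

POLL_INTERVAL_SECONDS = 1800  # 30 minutes between fine-tune status checks

def should_check_finetune_status(last_checked: str) -> bool:
    """True once the previous status check is more than 30 minutes old."""
    last = datetime.datetime.strptime(last_checked, "%Y-%m-%d %H:%M:%S")
    return (datetime.datetime.now() - last).total_seconds() > POLL_INTERVAL_SECONDS
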
2 changes: 1 addition & 1 deletion src/tanuki/language_models/language_model_manager.py
@@ -125,7 +125,7 @@ def get_generation_case(self, args, kwargs, function_description, llm_parameters
distilled_model, teacher_models = self.function_modeler.get_models(function_description)
is_distilled_model = distilled_model.model_name != ""
suitable_for_distillation, input_prompt_token_count = self.suitable_for_finetuning_token_check(args, kwargs, f,
distilled_model.context_length)
distilled_model)
# no examples needed, using a finetuned model. Dont save to finetune dataset
if is_distilled_model and suitable_for_distillation:
prompt = self.construct_prompt(f, args, kwargs, [], distilled_model)
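
The token check now receives the whole distilled_model config rather than just its context_length, so it can read any per-model fields it needs — for example the parsing_helper_tokens defined on BaseModelConfig further down in this PR. The following is only a guess at what such a check could look like, not the library's actual implementation:

from typing import Tuple

def suitable_for_finetuning_token_check(prompt: str, model_config) -> Tuple[bool, int]:
    """Sketch: does the prompt fit the distilled model's context window?

    Assumes `model_config` exposes `context_length` and an optional
    `parsing_helper_tokens` dict, as BaseModelConfig does in this PR.
    """
    token_count = len(prompt.split())  # naive estimate; real code would use a tokenizer
    helpers = getattr(model_config, "parsing_helper_tokens", None) or {}
    overhead = sum(len(str(value).split()) for value in helpers.values())
    return token_count + overhead < model_config.context_length, token_count
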
4 changes: 2 additions & 2 deletions src/tanuki/language_models/llm_configs/abc_base_config.py
@@ -26,6 +26,6 @@ class BaseModelConfig(abc.ABC, BaseModel):
system_message : str = f"You are a skillful and accurate language model, who applies a described function on input data. Make sure the function is applied accurately and correctly and the outputs follow the output type hints and are valid outputs given the output types."
instructions : str = "You are given below a function description and input data. The function description of what the function must carry out can be found in the Function section, with input and output type hints. The input data can be found in Input section. Using the function description, apply the function to the Input and return a valid output type, that is acceptable by the output_class_definition and output_class_hint. Return None if you can't apply the function to the input or if the output is optional and the correct output is None.\nINCREDIBLY IMPORTANT: Only output a JSON-compatible string in the correct response format."
repair_instruction: str = "Below are an outputs of a function applied to inputs, which failed type validation. The input to the function is brought out in the INPUT section and function description is brought out in the FUNCTION DESCRIPTION section. Your task is to apply the function to the input and return a correct output in the right type. The FAILED EXAMPLES section will show previous outputs of this function applied to the data, which failed type validation and hence are wrong outputs. Using the input and function description output the accurate output following the output_class_definition and output_type_hint attributes of the function description, which define the output type. Make sure the output is an accurate function output and in the correct type. Return None if you can't apply the function to the input or if the output is optional and the correct output is None."
system_message_token_count = -1
instruction_token_count = -1
system_message_token_count: int = -1
instruction_token_count: int = -1
parsing_helper_tokens: Optional[dict] = {"start_token": "", "end_token": ""}
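
These two attributes sit on a pydantic BaseModel, and without type annotations pydantic either ignores them as model fields or rejects them outright, depending on the version; annotating them as int turns them into ordinary fields with a default of -1 that can be set per instance. A minimal sketch, assuming pydantic:

from pydantic import BaseModel

class TokenCounts(BaseModel):
    # annotated attributes are real pydantic fields with defaults,
    # so they can be overridden per instance and travel with the model
    system_message_token_count: int = -1
    instruction_token_count: int = -1

cfg = TokenCounts(system_message_token_count=42)
print(cfg.system_message_token_count)  # 42
print(cfg.instruction_token_count)     # -1
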
21 changes: 13 additions & 8 deletions src/tanuki/language_models/openai_api.py
@@ -138,15 +138,16 @@ def list_finetuned(self, limit=100, **kwargs) -> List[FinetuneJob]:
response = self.client.fine_tuning.jobs.list(limit=limit)
jobs = []
for job in response.data:
model_config = copy.deepcopy(DEFAULT_GENERATIVE_MODELS[DEFAULT_DISTILLED_MODEL_NAME])
model_config.model_name = job.fine_tuned_model
jobs.append(FinetuneJob(job.id, job.status, model_config))
finetune_job = self.create_finetune_job(job)
jobs.append(finetune_job)

return jobs

def get_finetuned(self, job_id):
def get_finetuned(self, job_id) -> FinetuneJob:
self.check_api_key()
return self.client.fine_tuning.jobs.retrieve(job_id)
response = self.client.fine_tuning.jobs.retrieve(job_id)
finetune_job = self.create_finetune_job(response)
return finetune_job

def finetune(self, file, suffix, **kwargs) -> FinetuneJob:
self.check_api_key()
@@ -164,11 +165,15 @@ def finetune(self, file, suffix, **kwargs) -> FinetuneJob:
suffix=suffix)
except Exception as e:
return

finetune_job = FinetuneJob(finetuning_response.id, finetuning_response.status, finetuning_response.fine_tuned_model)

finetune_job = self.create_finetune_job(finetuning_response)
return finetune_job

def create_finetune_job(self, response: FineTuningJob) -> FinetuneJob:
finetuned_model_config = copy.deepcopy(DEFAULT_GENERATIVE_MODELS[DEFAULT_DISTILLED_MODEL_NAME])
finetuned_model_config.model_name = response.fine_tuned_model
finetune_job = FinetuneJob(response.id, response.status, finetuned_model_config)
return finetune_job

def check_api_key(self):
# check if api key is not none
if not self.api_key:
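
With this refactor, list_finetuned, get_finetuned, and finetune all build their FinetuneJob through the new create_finetune_job helper, which clones the default distilled-model config and attaches the fine-tuned model name, instead of each call site doing (or skipping) that conversion on its own. A simplified, self-contained sketch of the shape of that helper — the classes below are stand-ins, not the real tanuki or OpenAI types:

import copy
from dataclasses import dataclass

@dataclass
class ModelConfig:        # stand-in for the distilled model's BaseModelConfig
    model_name: str

@dataclass
class FinetuneJob:        # stand-in for tanuki's FinetuneJob
    id: str
    status: str
    fine_tuned_model: ModelConfig

DEFAULT_DISTILLED_CONFIG = ModelConfig(model_name="default-distilled-model")

def create_finetune_job(response) -> FinetuneJob:
    """Single conversion path: copy the default config, attach the fine-tuned model name."""
    config = copy.deepcopy(DEFAULT_DISTILLED_CONFIG)
    config.model_name = response.fine_tuned_model
    return FinetuneJob(response.id, response.status, config)
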
2 changes: 1 addition & 1 deletion src/tanuki/models/function_config.py
@@ -44,7 +44,7 @@ def load_from_dict(self, json_dict):
self.last_training_run = json_dict["last_training_run"]
self.current_training_run = json_dict["current_training_run"]
self.nr_of_training_runs = json_dict["nr_of_training_runs"]
if len(json_dict["teacher_models"]) > 0:
if "teacher_models" in json_dict and len(json_dict["teacher_models"]) > 0:
self.teacher_models = [config_factory.create_config(teacher_model, TEACHER_MODEL) for teacher_model in json_dict["teacher_models"]]
return self

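
The added "teacher_models" in json_dict guard makes load_from_dict tolerant of configs persisted before the teacher_models key existed. A small standalone sketch of the same idea using dict.get (the create_config default below is a hypothetical stand-in for tanuki's config_factory.create_config(..., TEACHER_MODEL)):

from typing import Any, Dict, List

def load_teacher_models(json_dict: Dict[str, Any],
                        create_config=lambda spec: spec) -> List[Any]:
    """Sketch: load teacher models, tolerating older configs without the key."""
    # .get(...) covers both a missing key and an empty list in one step
    return [create_config(spec) for spec in json_dict.get("teacher_models", [])]

print(load_teacher_models({"distilled_model": "abc"}))                  # []  (old config)
print(load_teacher_models({"teacher_models": ["some-teacher-model"]}))  # ["some-teacher-model"]
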
27 changes: 27 additions & 0 deletions tests/test_configure_MP.py
@@ -65,6 +65,12 @@ def func_default_openai(input: str) -> Optional[Literal['Good', 'Bad']]:
"""


@tanuki.patch
def func_default(input: str) -> Optional[Literal['Good', 'Bad']]:
"""
Determine if the input is positive or negative sentiment
"""

@tanuki.patch(teacher_models=[LlamaBedrockConfig(model_name = "llama778", context_length = 1)])
def func_full_llama_bedrock(input: str) -> Optional[Literal['Good', 'Bad']]:
"""
@@ -180,3 +186,24 @@ def test_teacher_model_override_error():
assert False
except:
assert True


def test_finetuning():
func_default_description = Register.load_function_description(func_default)
func_default_openai_description = Register.load_function_description(func_default_openai)
func_full_llama_bedrock_description = Register.load_function_description(func_full_llama_bedrock)
func_mixed_description = Register.load_function_description(func_mixed)
func_default_hash = func_default_description.__hash__()
func_default_openai_hash = func_default_openai_description.__hash__()
func_full_llama_bedrock_hash = func_full_llama_bedrock_description.__hash__()
func_mixed_hash = func_mixed_description.__hash__()

func_modeler = tanuki.function_modeler
assert func_default_hash not in func_modeler.check_finetune_blacklist
assert func_default_hash not in func_modeler.execute_finetune_blacklist
assert func_default_openai_hash not in func_modeler.check_finetune_blacklist
assert func_default_openai_hash not in func_modeler.execute_finetune_blacklist
assert func_full_llama_bedrock_hash in func_modeler.check_finetune_blacklist
assert func_full_llama_bedrock_hash in func_modeler.execute_finetune_blacklist
assert func_mixed_hash in func_modeler.check_finetune_blacklist
assert func_mixed_hash in func_modeler.execute_finetune_blacklist