This repository contains the original implementation of "Divide, Reweight, and Conquer: A Logit Arithmetic Approach for In-Context Learning". LARA is a novel framework that enhances in-context learning by dividing long input demonstrations into shorter, parallelizable subgroups and reweighting their logits using a non-gradient optimization approach.
The LARA framework divides the training examples into multiple subgroups and uses the gradient-free optimization algorithm CMA-ES to generate the weight vector w.
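To make the reweighting step concrete, here is a minimal, illustrative sketch (not the repository's actual implementation) of how a weight vector w could be searched with nevergrad's CMA optimizer. It assumes the logits produced from each demonstration subgroup have already been precomputed over a small validation set; `subgroup_logits`, `val_labels`, and `loss` below are hypothetical placeholders.

import numpy as np
import nevergrad as ng

# Hypothetical precomputed logits: one (num_val_examples, vocab_size) array per
# demonstration subgroup, plus the gold answer-token ids for the validation set.
subgroup_logits = [np.random.randn(16, 128) for _ in range(4)]
val_labels = np.random.randint(0, 128, size=16)

def loss(w):
    # Weighted combination of the subgroup logits, then cross-entropy on the labels.
    combined = sum(wi * lg for wi, lg in zip(w, subgroup_logits))
    shifted = combined - combined.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    return -log_probs[np.arange(len(val_labels)), val_labels].mean()

optimizer = ng.optimizers.CMA(parametrization=ng.p.Array(shape=(4,)), budget=200)
w_star = optimizer.minimize(loss).value  # recommended reweighting vector w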
- 2024.10: Initial release of the code, models, and paper.
Install the required Python libraries by running the command below:
pip install -r requirements.txt
Alternatively, create a conda environment by running the command below:
conda env create -f environment.yml
To run the main experiments for B-LARA, use the following commands:
## BBH
python hyper_search_loss.py --model meta-llama/Meta-Llama-3.1-8B --output_path llama31 --dataset_name BBH --subdataset_name date_understanding --train_length 32 --binary
## MMLU
python hyper_search_loss.py --model meta-llama/Meta-Llama-3.1-8B --output_path llama31 --dataset_name MMLU --subdataset_name anatomy --train_length 32 --binary
## GoEmotion
python hyper_search_loss.py --model meta-llama/Meta-Llama-3.1-8B --output_path llama31 --dataset_name emotion --train_length 8 --binary
## Tacred
python hyper_search_loss.py --model meta-llama/Meta-Llama-3.1-8B --output_path llama31 --dataset_name tacred --train_length 4 --binary
To run the main experiments for LARA, use the following commands:
## BBH
python hyper_search_loss.py --model meta-llama/Meta-Llama-3.1-8B --output_path llama31 --dataset_name BBH --subdataset_name date_understanding --train_length 32
## MMLU
python hyper_search_loss.py --model meta-llama/Meta-Llama-3.1-8B --output_path llama31 --dataset_name MMLU --subdataset_name anatomy --train_length 32
## GoEmotion
python hyper_search_loss.py --model meta-llama/Meta-Llama-3.1-8B --output_path llama31 --dataset_name emotion --train_length 8
## Tacred
python hyper_search_loss.py --model meta-llama/Meta-Llama-3.1-8B --output_path llama31 --dataset_name tacred --train_length 4
To implement LARA for a new dataset, follow these steps:
- Add the dataset loader in `data_set.py` by implementing a new class.
class Example_Dataset(Dataset):
    def __init__(self, path, train_length=64):
        super().__init__()
        # self.train holds the demonstration pool as a list of input/output pairs.
        self.train = [
            {'input': 'What is the paper mainly about?', 'output': 'A novel efficient inference framework'},
            {'input': 'What is the name of the framework?', 'output': 'LARA'},
        ]
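As a quick sanity check, the loader can be instantiated and inspected directly (the data path below is a placeholder):

dataset = Example_Dataset('path/to/your/data', train_length=8)
print(dataset.train[0]['input'])  # "What is the paper mainly about?"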
- Implement the evaluation metrics in `results.py`.
class Example_Results(Results):
    def post_processing(self):
        candidates = ['A', 'B', 'C', 'D']
        for i in range(len(self.results)):
            output = self.results[i]['output']
            # Keep only the text generated before the next "Question" marker.
            output = output.split('Question')[0]
            answer_output = None
            for answer in candidates:
                if output in answer or answer in output:
                    answer_output = answer
                    break
            # Store the extracted answer, or an empty string if none was found.
            self.results[i]['answer'] = answer_output.upper() if answer_output else ''

    def calculate_score(self):
        self.post_processing()
        count = 0
        for result in self.results:
            if result['truth'] == result['answer']:
                count += 1
        return count / len(self.results)
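For illustration, assuming the base `Results` class simply stores the list passed to it in `self.results`, the accuracy can be computed like this (the raw records below are made up):

raw = [{'truth': 'A', 'output': 'A'}, {'truth': 'B', 'output': 'C is correct'}]
results = Example_Results(raw)
print(results.calculate_score())  # 0.5: the first prediction matches, the second does not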
- Add the prompts in `prompt.py`, including the template that transforms data into demonstrations, for example:
def get_example_prompt(datas, subject=None):
    ans = ''
    for data in datas:
        ans += f"Question: {data['input']}\nAnswer: {data['output']}\n\n"
    return ans
Then define the prompt template for question answering. Here, `formula_string` represents the prompted in-context examples:
prompt_templates = {
    'example': lambda formula_string, input_string: f"{formula_string}Question: {input_string}\nAnswer:"
}
Also, register the prompt function in `prompt.py`:
prompts = {'example': get_example_prompt}
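To see how these two pieces fit together, here is an illustrative composition of the templates (the example data is made up):

demos = get_example_prompt([{'input': '2 + 2 = ?', 'output': '4'}])
full_prompt = prompt_templates['example'](demos, 'What is 3 + 5?')
# "Question: 2 + 2 = ?\nAnswer: 4\n\nQuestion: What is 3 + 5?\nAnswer:"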
- Register the new dataset and results classes in `utils.py`:
def get_path_dataset(output_path, dataset_name, subdataset_name, train_length):
    if dataset_name == 'example':
        dataset = Example_Dataset(subdataset_name, train_length)
        filepath = f'results/{output_path}/{dataset_name}/{subdataset_name}/'

def get_results(dataset_name, results):
    if dataset_name == 'example':
        results = Example_Results(results)
The main code also provides ICL baseline results. The KATE results can be obtained by running `retrieve_baseline.py`:
python retrieve_baseline.py --dataset_name BBH
The ablation study can be conducted using `nonreweight.py`:
python nonreweight.py --dataset_name BBH
The code base is adapted from https://github.com/eth-sri/language-model-arithmetic. We are very grateful to the authors for providing an excellent code base on which to implement LARA. We also thank the open-source non-gradient optimization library nevergrad (https://github.com/facebookresearch/nevergrad).
If you have questions, please open an issue or send an email to chengsong[at]wustl.edu.
Please cite our paper if you find the repo helpful in your work:
@article{Huang2024DivideRC,
  title={Divide, Reweight, and Conquer: A Logit Arithmetic Approach for In-Context Learning},
  author={Chengsong Huang and Langlin Huang and Jiaxin Huang},
  year={2024},
  journal={ArXiv},
  volume={abs/2410.10074},
  url={https://arxiv.org/abs/2410.10074},
}