# Hugging Face Space: Running on Zero (ZeroGPU hardware)
# NOTE(review): `spaces` is kept as the first import, matching the original order;
# presumably ZeroGPU wants it imported before torch touches CUDA — confirm.
import spaces
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load models. All three models share the tokenizer of the implicit-CoT checkpoint.
implicit_cot_model_name = 'yuntian-deng/gpt2-implicit-cot-multiplication-20-digits'
implicit_cot_model = AutoModelForCausalLM.from_pretrained(implicit_cot_model_name)
tokenizer = AutoTokenizer.from_pretrained(implicit_cot_model_name)
no_cot_model_name = 'yuntian-deng/gpt2-no-cot-multiplication'
no_cot_model = AutoModelForCausalLM.from_pretrained(no_cot_model_name)
explicit_cot_model_name = 'yuntian-deng/gpt2-explicit-cot-multiplication-20-digits'
explicit_cot_model = AutoModelForCausalLM.from_pretrained(explicit_cot_model_name)

# The keys ('implicit', 'no', 'explicit') are the model identifiers used by
# predict_product and by MAX_PRODUCT_DIGITS_PER_MODEL below.
models = {'implicit': implicit_cot_model, 'no': no_cot_model, 'explicit': explicit_cot_model}

# Plain loop instead of list comprehensions: these calls are for their side
# effects only, so there is no list to build.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
for model in models.values():
    model.to(device)
    model.eval()  # inference only (disables dropout)

# Constants
# Per-model cap on decoding steps (one generated token per step) in
# predict_product. Despite the name, it bounds tokens, not product digits.
# The explicit-CoT model needs far more steps because it emits its reasoning
# tokens before the '=' that precedes the product.
MAX_PRODUCT_DIGITS_PER_MODEL = {'implicit': 100, 'no': 100, 'explicit': 1070}
def preprocess(num):
    """Convert an operand into the reversed, space-separated form fed to the models.

    The operand is stringified, surrounding whitespace and internal spaces are
    removed, and the remaining characters are emitted least-significant first,
    separated by single spaces.

    >>> preprocess('123')
    '3 2 1'
    """
    digits = str(num).strip().replace(' ', '')
    return ' '.join(reversed(digits))
def postprocess(raw_output):
    """Undo the model's digit formatting: drop all spaces and restore normal digit order.

    >>> postprocess('3 2 1')
    '123'
    """
    compact = raw_output.replace(' ', '')
    return ''.join(reversed(compact))
def _ground_truth_annotations(ground_truth_digits_reversed, valid_input, step):
    """Return the ground-truth cells to display after decoding step ``step``.

    The true product is revealed one low-order digit per step (``step + 1``
    digits in total) and listed most-significant first, like the predictions.
    """
    if not valid_input:
        return [('Invalid Input!', None)]
    revealed = ground_truth_digits_reversed[:step + 1]
    return [(digit, None) for digit in revealed[::-1]]


def _annotate_prediction(model_name, output_text, ground_truth_digits_reversed, valid_input):
    """Label each decoded digit of one model's output as correct or wrong.

    Args:
        model_name: Key into ``models``. For ``'explicit'`` the output is split
            at the ``'='`` token, and everything up to and including it
            (the reasoning steps) is shown unlabeled.
        output_text: Text generated after the prompt, space-separated,
            least-significant digit first.
        ground_truth_digits_reversed: True product digits, least significant
            first (empty when the input is invalid).
        valid_input: False when the operands could not be parsed as ints; all
            digits are then left unlabeled.

    Returns:
        A list of ``(text, category)`` pairs, most significant first. The
        category is ``'correct'``, ``'wrong'`` or ``None``.
    """
    tokens = output_text.strip().split(' ')
    annotations = []
    if model_name == 'explicit':
        if '=' not in tokens:
            # Still emitting reasoning steps, so there are no answer digits to score yet.
            annotations = [(token, None) for token in tokens]
            tokens = []
        else:
            equal_sign_position = tokens.index('=')
            annotations = [(token, None) for token in tokens[:equal_sign_position + 1]]
            tokens = tokens[equal_sign_position + 1:]

    is_correct_sofar = True
    for i, predicted_digit in enumerate(tokens):
        if not valid_input:
            annotations.append((predicted_digit, None))
            continue
        if i >= len(ground_truth_digits_reversed):
            # Digits beyond the true product's length count as correct only if
            # they are '0' and every lower-order digit so far was correct.
            is_correct_digit = predicted_digit == '0' and is_correct_sofar
        else:
            is_correct_digit = predicted_digit == ground_truth_digits_reversed[i]
        if not is_correct_digit:
            is_correct_sofar = False
        annotations.append((predicted_digit, 'correct' if is_correct_digit else 'wrong'))
    return annotations[::-1]


# This Space runs on ZeroGPU and `spaces` is imported at the top of the file
# but was otherwise unused. The decorator requests a GPU for each call.
# NOTE(review): long explicit-CoT runs may exceed the default time budget;
# confirm, and pass `duration=` if needed.
@spaces.GPU
def predict_product(num1, num2):
    """Stream the products predicted by the implicit-, no- and explicit-CoT models.

    Both operands are fed to every model in reversed, space-separated form
    (see ``preprocess``). Each loop iteration greedily decodes one token per
    still-running model, then yields the current state so the UI updates as
    tokens arrive.

    Args:
        num1: First operand as entered in the UI.
        num2: Second operand as entered in the UI.

    Yields:
        A 4-tuple ``(ground_truth, implicit, no_cot, explicit)``. Each element
        is a list of ``(text, category)`` pairs for a ``gr.HighlightedText``
        panel, most significant digit first.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    input_text = f'{preprocess(num1)} * {preprocess(num2)} ='
    inputs = tokenizer(input_text, return_tensors='pt').to(device)
    # Re-send the models to the device on every call, as the original did.
    # Presumably this is needed because the ZeroGPU device is attached per call.
    for model in models.values():
        model.to(device)
    input_ids = inputs['input_ids']
    input_len = input_ids.shape[-1]

    valid_input = True
    ground_truth_digits_reversed = []
    try:
        ground_truth_product = str(int(num1) * int(num2))
        ground_truth_digits_reversed = list(ground_truth_product)[::-1]
    except ValueError:
        valid_input = False

    generated_ids_per_model = {name: input_ids.clone() for name in models}
    finished_per_model = {name: False for name in models}
    past_key_values_per_model = {name: None for name in models}
    # Pre-seeded so that a model emitting EOS on its very first step does not
    # raise KeyError when its (empty) annotations are read below.
    predicted_annotations_per_model = {name: [] for name in models}

    max_steps = max(MAX_PRODUCT_DIGITS_PER_MODEL.values())
    for step in range(max_steps):  # hard cap prevents an endless decoding loop
        ground_truth_annotations = _ground_truth_annotations(
            ground_truth_digits_reversed, valid_input, step)

        for model_name, model in models.items():
            if finished_per_model[model_name] or step >= MAX_PRODUCT_DIGITS_PER_MODEL[model_name]:
                continue
            generation_kwargs = {
                'input_ids': generated_ids_per_model[model_name],
                'max_new_tokens': 1,
                'do_sample': False,
                'return_dict_in_generate': True,
                'use_cache': True,
            }
            # Step 0 has no cache yet; later steps reuse the KV cache from the previous token.
            if step > 0:
                generation_kwargs['past_key_values'] = past_key_values_per_model[model_name]
            outputs = model.generate(**generation_kwargs)
            generated_ids = outputs.sequences
            next_token_id = generated_ids[0, -1]

            if next_token_id.item() == tokenizer.eos_token_id:
                finished_per_model[model_name] = True
                if valid_input:
                    annotations = predicted_annotations_per_model[model_name]
                    n_scored = sum(1 for _, category in annotations if category is not None)
                    if n_scored < len(ground_truth_digits_reversed):
                        # Stopped before producing every product digit: show a
                        # blank "wrong" cell in front of the prediction.
                        annotations.insert(0, ('⠀', 'wrong'))
                continue

            generated_ids_per_model[model_name] = generated_ids
            past_key_values_per_model[model_name] = outputs.past_key_values
            output_text = tokenizer.decode(generated_ids[0, input_len:], skip_special_tokens=True)
            predicted_annotations_per_model[model_name] = _annotate_prediction(
                model_name, output_text, ground_truth_digits_reversed, valid_input)

        yield (
            ground_truth_annotations,
            predicted_annotations_per_model['implicit'],
            predicted_annotations_per_model['no'],
            predicted_annotations_per_model['explicit'],
        )

        # Once every model is finished or capped and the ground truth is fully
        # revealed, later iterations would only re-yield this exact output.
        all_models_done = all(
            finished_per_model[name] or step + 1 >= MAX_PRODUCT_DIGITS_PER_MODEL[name]
            for name in models
        )
        ground_truth_complete = not valid_input or step + 1 >= len(ground_truth_digits_reversed)
        if all_models_done and ground_truth_complete:
            break
# Maps the categories produced by predict_product to highlight colours.
color_map = {"correct": "green", "wrong": "red"}

# Options shared by every output panel: keep each digit as its own span,
# hide the legend, and use the colour map above.
_panel_options = dict(combine_adjacent=False, show_legend=False, color_map=color_map)

demo = gr.Interface(
    fn=predict_product,
    inputs=[
        gr.Textbox(label='First Number (up to 20 digits)', value='12345678912345678912'),
        gr.Textbox(label='Second Number (up to 20 digits)', value='98765432198765432198'),
    ],
    # Output order must match the 4-tuple yielded by predict_product.
    outputs=[
        gr.HighlightedText(label='Ground Truth Product', **_panel_options),
        gr.HighlightedText(label='Implicit CoT Prediction (Ours)', show_inline_category=False, **_panel_options),
        gr.HighlightedText(label='No CoT Prediction', show_inline_category=False, **_panel_options),
        gr.HighlightedText(label='Explicit CoT Steps & Prediction', show_inline_category=False, **_panel_options),
    ],
    title='Predicting Multiplication with GPT-2: Implicit vs. Explicit CoT',
    description='This demo showcases GPT-2\'s ability to directly predict the product of two large numbers without intermediate steps, using our stepwise internalization method. Compare the performance of implicit CoT (our method), no CoT, and explicit CoT. Implicit CoT offers accuracy and speed, while explicit CoT provides detailed reasoning but is slower.',
    article="""
- [Paper 1: Implicit Chain of Thought Reasoning via Knowledge Distillation](https://arxiv.org/pdf/2311.01460)
- [Paper 2: From Explicit CoT to Implicit CoT: Learning to Internalize CoT Step by Step](https://arxiv.org/pdf/2405.14838)
- [Code Repository](https://github.com/da03/Internalize_CoT_Step_by_Step)
- [Tweet Announcement](https://twitter.com/yuntiandeng/status/1795854740879774036)
""",
    clear_btn=None,
    submit_btn="Multiply!",
    live=False,
    concurrency_limit=1,
)

# Launches at import time, as before.
# NOTE(review): consider wrapping in `if __name__ == "__main__":` if this
# module is ever imported rather than run as a script.
demo.queue(max_size=20).launch()