da03
.
dc59d8f
raw
history blame
7.69 kB
import spaces
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
# --- Model loading -----------------------------------------------------------
# Hub repositories for the three GPT-2 variants compared in this demo.
implicit_cot_model_name = 'yuntian-deng/gpt2-implicit-cot-multiplication'
no_cot_model_name = 'yuntian-deng/gpt2-no-cot-multiplication'
explicit_cot_model_name = 'yuntian-deng/gpt2-explicit-cot-multiplication'

implicit_cot_model = AutoModelForCausalLM.from_pretrained(implicit_cot_model_name)
# A single tokenizer, taken from the implicit-CoT repo, is used for all three models.
tokenizer = AutoTokenizer.from_pretrained(implicit_cot_model_name)
no_cot_model = AutoModelForCausalLM.from_pretrained(no_cot_model_name)
explicit_cot_model = AutoModelForCausalLM.from_pretrained(explicit_cot_model_name)

# Short keys used to look up each model (and its step budget) throughout the app.
models = dict(implicit=implicit_cot_model, no=no_cot_model, explicit=explicit_cot_model)

# --- Constants ---------------------------------------------------------------
# Maximum number of single-token generation steps each model gets per request.
# The explicit-CoT model emits reasoning tokens before its answer, which is
# presumably why its budget is much larger.
MAX_PRODUCT_DIGITS_PER_MODEL = dict(implicit=100, no=100, explicit=900)
def preprocess(num):
    """Format a number the way the models expect it: digits reversed, space-separated.

    The value is stringified and stripped, and every space is removed before
    reversing, so ``' 12 3'`` becomes ``'3 2 1'``.
    """
    digits = str(num).strip().replace(' ', '')
    return ' '.join(reversed(digits))
def postprocess(raw_output):
    """Undo ``preprocess``-style formatting: drop spaces and reverse the digit order."""
    return ''.join(reversed(raw_output.replace(' ', '')))
def _annotate_prediction(model_name, output_text, ground_truth_digits_reversed):
    """Convert one model's decoded output into ``HighlightedText`` annotations.

    Args:
        model_name: key of ``models`` that produced the text. For ``'explicit'``,
            every token up to and including the first ``'='`` is treated as
            intermediate reasoning and left unlabelled.
        output_text: decoded tokens generated after the prompt, space-separated,
            in the reversed (least-significant-first) order the models emit.
        ground_truth_digits_reversed: digits of the true product, least
            significant first, or ``None`` when the inputs could not be parsed
            as integers (predictions are then shown ungraded).

    Returns:
        List of ``(token, label)`` pairs in display order (most significant
        first), where label is ``"correct"``, ``"wrong"`` or ``None``.
    """
    # split() rather than split(' ') so stray/repeated whitespace never yields '' tokens.
    tokens = output_text.split()
    annotations = []
    if model_name == 'explicit':
        # Everything through the first '=' is chain-of-thought; only what follows is the answer.
        answer_start = tokens.index('=') + 1 if '=' in tokens else len(tokens)
        annotations = [(token, None) for token in tokens[:answer_start]]
        tokens = tokens[answer_start:]

    if ground_truth_digits_reversed is None:
        # No valid ground truth: show the prediction without grading it.
        annotations.extend((token, None) for token in tokens)
        return annotations[::-1]

    is_correct_sofar = True
    for i, predicted_digit in enumerate(tokens):
        if i < len(ground_truth_digits_reversed):
            is_correct_digit = predicted_digit == ground_truth_digits_reversed[i]
        else:
            # Beyond the true product's length, only zero padding that follows an
            # all-correct prefix counts as correct.
            is_correct_digit = predicted_digit == '0' and is_correct_sofar
        is_correct_sofar = is_correct_sofar and is_correct_digit
        annotations.append((predicted_digit, "correct" if is_correct_digit else "wrong"))
    return annotations[::-1]


@spaces.GPU
def predict_product(num1, num2):
    """Stream digit-by-digit predictions of ``num1 * num2`` from all three models.

    Each loop step asks every still-active model for one more token (greedy
    decoding, reusing that model's KV cache) and yields the current state so the
    UI updates live.

    Args:
        num1, num2: the operands as entered in the UI text boxes; spaces are ignored.

    Yields:
        A 4-tuple ``(ground_truth, implicit, no_cot, explicit)`` of annotation
        lists for the four ``HighlightedText`` outputs. The ground truth is
        revealed one digit per step (least significant first) and is empty when
        either input is not an integer.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # NOTE(review): models are moved to the device on every call, presumably
    # because the GPU is only attached inside @spaces.GPU functions — confirm.
    for model in models.values():
        model.to(device)

    input_text = f'{preprocess(num1)} * {preprocess(num2)} ='
    inputs = tokenizer(input_text, return_tensors='pt').to(device)
    input_ids = inputs['input_ids']
    input_len = input_ids.shape[-1]

    # Ground truth. Spaces are removed exactly as preprocess() does, so the
    # ground truth and the model see the same number.
    try:
        num1_int = int(str(num1).replace(' ', ''))
        num2_int = int(str(num2).replace(' ', ''))
        ground_truth_digits_reversed = list(str(num1_int * num2_int))[::-1]
    except ValueError:
        ground_truth_digits_reversed = None
    num_ground_truth_digits = len(ground_truth_digits_reversed or [])

    generated_ids_per_model = {name: input_ids.clone() for name in models}
    past_key_values_per_model = {name: None for name in models}
    finished_per_model = {name: False for name in models}
    # Start from [] so a model that emits EOS on its very first step still has something to show.
    predicted_annotations_per_model = {name: [] for name in models}

    # Hard cap on steps to prevent runaway generation.
    for step in range(max(MAX_PRODUCT_DIGITS_PER_MODEL.values())):
        ground_truth_annotations = [
            (digit, None) for digit in (ground_truth_digits_reversed or [])[:step + 1]
        ][::-1]

        for model_name, model in models.items():
            if finished_per_model[model_name] or step >= MAX_PRODUCT_DIGITS_PER_MODEL[model_name]:
                continue
            generation_kwargs = {
                'input_ids': generated_ids_per_model[model_name],
                'max_new_tokens': 1,
                'do_sample': False,
                'return_dict_in_generate': True,
                'use_cache': True,
            }
            # There is no cache yet on the first step.
            if step > 0:
                generation_kwargs['past_key_values'] = past_key_values_per_model[model_name]
            outputs = model.generate(**generation_kwargs)
            generated_ids = outputs.sequences
            if generated_ids[0, -1].item() == tokenizer.eos_token_id:
                finished_per_model[model_name] = True
                continue
            generated_ids_per_model[model_name] = generated_ids
            past_key_values_per_model[model_name] = outputs.past_key_values
            output_text = tokenizer.decode(generated_ids[0, input_len:], skip_special_tokens=True)
            predicted_annotations_per_model[model_name] = _annotate_prediction(
                model_name, output_text, ground_truth_digits_reversed
            )

        yield (
            ground_truth_annotations,
            predicted_annotations_per_model['implicit'],
            predicted_annotations_per_model['no'],
            predicted_annotations_per_model['explicit'],
        )

        # Once no model can produce more tokens and the full ground truth is on
        # screen, further iterations would only re-yield identical output.
        ground_truth_done = step + 1 >= num_ground_truth_digits
        models_done = all(
            finished_per_model[name] or step + 1 >= MAX_PRODUCT_DIGITS_PER_MODEL[name]
            for name in models
        )
        if ground_truth_done and models_done:
            break
# Label -> colour for the HighlightedText panels; "correct"/"wrong" are the
# labels predict_product attaches to graded digits.
color_map = {"correct": "green", "wrong": "red"}

# Display settings shared by every output panel.
_HIGHLIGHT_KWARGS = dict(combine_adjacent=False, show_legend=False, color_map=color_map)
# The three model-prediction panels, in the order predict_product yields them.
_PREDICTION_LABELS = ('Implicit CoT Prediction', 'No CoT Prediction', 'Explicit CoT Steps & Prediction')

demo = gr.Interface(
    fn=predict_product,
    inputs=[
        gr.Textbox(label='First Number (up to 15 digits)', value='123456789'),
        gr.Textbox(label='Second Number (up to 15 digits)', value='987654321'),
    ],
    outputs=[gr.HighlightedText(label='Ground Truth Product', **_HIGHLIGHT_KWARGS)]
    + [
        gr.HighlightedText(label=label, show_inline_category=False, **_HIGHLIGHT_KWARGS)
        for label in _PREDICTION_LABELS
    ],
    title='Predicting Multiplication with GPT-2: Implicit vs. Explicit CoT',
    description='This demo showcases GPT-2\'s ability to directly predict the product of two large numbers without intermediate steps, using our stepwise internalization method. Compare the performance of implicit CoT (our method), no CoT, and explicit CoT. Implicit CoT offers accuracy and speed, while explicit CoT provides detailed reasoning but is slower.',
    article="""
- [Paper: From Explicit CoT to Implicit CoT: Learning to Internalize CoT Step by Step](https://arxiv.org/pdf/2405.14838)
- [Code Repository](https://github.com/da03/Internalize_CoT_Step_by_Step)
- [Tweet Announcement](https://twitter.com/yuntiandeng/status/1795854740879774036)
""",
    clear_btn=None,
    submit_btn="Multiply!",
    live=False,
    concurrency_limit=1,
)
# Allow at most 20 queued requests; the app launches at module level.
demo.queue(max_size=20).launch()