Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,798 Bytes
02df9f8 5c3cea5 02df9f8 9bfa66b 02df9f8 9bfa66b 02df9f8 9428a07 02df9f8 39a2dae eaa0586 8fa0ae4 70487ef eaa0586 70487ef eaa0586 6cc23f5 eaa0586 6cc23f5 eaa0586 8fa0ae4 eaa0586 8fa0ae4 70487ef 6cc23f5 02df9f8 3f861c3 513d0fe 3f861c3 9bfa66b 6cc23f5 8fa0ae4 9bfa66b 1efd23b 9428a07 8fa0ae4 486c21f 6cc23f5 02df9f8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 |
import spaces
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
# Hugging Face Hub identifier of the checkpoint used by this demo.
model_name = 'yuntian-deng/gpt2-implicit-cot-multiplication'
# Tokenizer and causal-LM weights are loaded once at import time and shared by
# every call to predict_product.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
def preprocess(num):
    """Return the characters of *num* in reverse order, separated by spaces.

    *num* is converted with ``str``, surrounding whitespace is stripped and
    every space character is removed before reversing, so ``12345`` becomes
    ``'5 4 3 2 1'``.
    """
    text = str(num).strip()
    compact = "".join(text.split(" "))
    return " ".join(reversed(compact))
def postprocess(raw_output):
    """Drop every space from *raw_output* and return the rest reversed.

    This undoes the spacing and reversal applied by ``preprocess``:
    ``'0 2 1'`` becomes ``'120'``.
    """
    kept = [ch for ch in raw_output if ch != " "]
    kept.reverse()
    return "".join(kept)
def _highlight_digits(prediction, correct_product):
    """Pair each predicted character with a HighlightedText label.

    A character gets ``None`` (no highlight) when the same character sits at
    the same index of *correct_product*; otherwise it gets ``"+"``, which the
    interface's colour map renders in red.
    """
    return [
        (char, None if idx < len(correct_product) and char == correct_product[idx] else "+")
        for idx, char in enumerate(prediction)
    ]


@spaces.GPU
def predict_product(num1, num2):
    """Stream the model's predicted product of *num1* and *num2*.

    This is a generator. After every generated token it yields
    ``(diff, "")`` where ``diff`` is a list of ``(character, label)`` tuples
    for ``gr.HighlightedText``; once generation stops it yields
    ``(diff, result_message)`` with the final verdict.

    Args:
        num1: First factor, as entered by the user.
        num2: Second factor, as entered by the user.

    Yields:
        tuple[list[tuple[str, str | None]], str]: highlighted prediction and
        a result message (empty while generation is still in progress).
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    input_text = f'{preprocess(num1)} * {preprocess(num2)} ='
    inputs = tokenizer(input_text, return_tensors='pt').to(device)
    model.to(device)

    generated_ids = inputs['input_ids']
    # Remember where the prompt ends so only newly generated tokens are
    # decoded; decoding the full sequence would put the (reversed) prompt
    # into the prediction and make it impossible to match the true product.
    prompt_length = generated_ids.shape[-1]
    eos_token_id = tokenizer.eos_token_id
    prediction = ""

    # Reference answer computed with exact integer arithmetic. Input that
    # cannot be parsed leaves it empty, so every predicted digit is flagged.
    try:
        correct_product = str(int(num1) * int(num2))
        valid_input = True
    except (TypeError, ValueError):
        correct_product = ""
        valid_input = False

    for _ in range(40):  # Adjust the range to control the maximum number of generated tokens
        outputs = model.generate(generated_ids, max_new_tokens=1, do_sample=False)
        next_token = outputs[:, -1:]
        # Stop at end-of-sequence: anything generated afterwards is not part
        # of the answer.
        if eos_token_id is not None and next_token[0, 0].item() == eos_token_id:
            break
        generated_ids = torch.cat((generated_ids, next_token), dim=-1)
        output_text = tokenizer.decode(generated_ids[0, prompt_length:], skip_special_tokens=True)
        prediction = postprocess(output_text)
        # NOTE(review): the comparison is left-aligned, as before. The output
        # is presumably produced least-significant digit first (postprocess
        # reverses it), so intermediate diffs may mark digits as wrong until
        # the prediction reaches its full length - confirm whether
        # right-alignment is wanted for the streaming view.
        yield _highlight_digits(prediction, correct_product), ""

    if valid_input:
        if prediction == correct_product:
            result_message = "Correct!"
        else:
            result_message = f"Incorrect! The correct product is {correct_product}."
    else:
        result_message = "Invalid input. Could not evaluate correctness."

    # Final diff for the complete prediction
    yield _highlight_digits(prediction, correct_product), result_message
# Gradio UI: two text inputs feed predict_product, whose yielded
# (highlighted digits, message) pairs fill the two outputs below.
demo = gr.Interface(
    fn=predict_product,
    inputs=[
        gr.Textbox(label='First Number (up to 12 digits)', value='12345'),
        gr.Textbox(label='Second Number (up to 12 digits)', value='67890'),
    ],
    outputs=[
        # predict_product labels mismatching digits "+" (mapped to red here)
        # and leaves matching digits unlabelled, so the label names the
        # incorrect digits as the highlighted ones.
        gr.HighlightedText(label='Predicted Product with Incorrect Digits Highlighted', combine_adjacent=True, show_legend=True, color_map={"+": "red"}),
        gr.HTML(label='Result Message')
    ],
    title='GPT2 Direct Multiplication Calculator (Without Using Chain-of-Thought)',
    description='This demo uses GPT2 to directly predict the product of two numbers without using any intermediate reasoning steps. The GPT2 model has been fine-tuned to internalize chain-of-thought reasoning within its hidden states, following our stepwise internalization approach detailed in the paper linked at the bottom of this page.',
    article="""
- [Paper: From Explicit CoT to Implicit CoT: Learning to Internalize CoT Step by Step](https://arxiv.org/pdf/2405.14838)
- [Code Repository](https://github.com/da03/Internalize_CoT_Step_by_Step)
- [Tweet Announcement](https://twitter.com/yuntiandeng/status/1795854740879774036)
""",
    clear_btn=None,
    submit_btn="Multiply!",
    live=False
)
# Start the web server for the interface.
demo.launch()
|