import spaces
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = 'yuntian-deng/gpt2-implicit-cot-multiplication'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

def preprocess(num):
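    # The model expects each digit as its own space-separated token, with the
    # least-significant digit first, e.g. '12345' -> '5 4 3 2 1'.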
    num = str(num).strip().replace(' ', '')
    reversed_num = ' '.join(num[::-1])
    return reversed_num

def postprocess(raw_output):
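    # Undo the preprocessing: remove the spaces and reverse the digit order so
    # the predicted product reads most-significant digit first.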
    prediction = raw_output.replace(' ', '')[::-1]
    return prediction

@spaces.GPU
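# The decorator above requests a GPU slot from Hugging Face ZeroGPU for the
# duration of each call when this app runs on Spaces.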
def predict_product(num1, num2):
    # Reverse input digits and add spaces
    input_text = f'{preprocess(num1)} * {preprocess(num2)} ='

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    inputs = tokenizer(input_text, return_tensors='pt').to(device)
    model.to(device)

    # Generate output
    outputs = model.generate(**inputs, max_new_tokens=40)

    # Keep only the newly generated tokens (drop the echoed prompt)
    output = outputs[0][inputs['input_ids'].shape[-1]:]
    raw_output = tokenizer.decode(output, skip_special_tokens=True)
    prediction = postprocess(raw_output)

    # Evaluate the correctness of the result
    try:
        num1_int = int(num1)
        num2_int = int(num2)
        valid_input = True
    except ValueError:
        valid_input = False
    if valid_input:
        correct_product = num1_int * num2_int
        try:
            prediction_int = int(prediction)
            is_correct = (prediction_int == correct_product)
        except ValueError:
            is_correct = False
        result_color = "green" if is_correct else "red"
        result_message = "Correct!" if is_correct else f"Incorrect! The correct product is {correct_product}."
    else:
        result_color = "black"
        result_message = "Invalid input. Could not evaluate correctness."
    result_html = f"<div style='color: {result_color};'>{result_message}</div>"

    return input_text, raw_output, prediction, result_html

demo = gr.Interface(
    fn=predict_product,
    inputs=[
        gr.Textbox(label='First Number (up to 12 digits)', value='12345'),
        gr.Textbox(label='Second Number (up to 12 digits)', value='67890'),
    ],
    outputs=[
        gr.Textbox(label='Raw Input to GPT-2 (reversed digits and added spaces)'),
        gr.Textbox(label='Raw Output from GPT-2 (reversed digits and with spaces)'),
        gr.Textbox(label='Predicted Product'),
        gr.HTML(label='Result Message')
    ],
    title='GPT-2 Direct Multiplication Calculator (Without Using Chain-of-Thought)',
    description='This demo uses GPT-2 to directly predict the product of two numbers without any intermediate reasoning steps. The model is fine-tuned to internalize chain-of-thought reasoning in its hidden states, using our stepwise internalization approach detailed in the paper below.',
    article="""
    - [Paper: From Explicit CoT to Implicit CoT: Learning to Internalize CoT Step by Step](https://arxiv.org/pdf/2405.14838)
    - [Code Repository](https://github.com/da03/Internalize_CoT_Step_by_Step)
    - [Tweet Announcement](https://twitter.com/yuntiandeng/status/1795854740879774036)
    """,
    live=False
)

demo.launch()