import spaces
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the GPT-2 model fine-tuned to multiply numbers without explicit chain-of-thought
model_name = 'yuntian-deng/gpt2-implicit-cot-multiplication'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

def preprocess(num):
    """Reverse the digits of `num` and space-separate them to match the model's training format."""
    num = str(num).strip().replace(' ', '')
    reversed_num = ' '.join(num[::-1])
    return reversed_num

def postprocess(raw_output):
    """Strip the spaces and undo the digit reversal to recover a normal number."""
    prediction = raw_output.replace(' ', '')[::-1]
    return prediction
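
# Round trip of the digit-reversal helpers (illustrative values):
#   preprocess('12345')      -> '5 4 3 2 1'
#   postprocess('5 4 3 2 1') -> '12345'
# Reversing puts the least-significant digit first, the order in which long
# multiplication naturally produces its output digits.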

@spaces.GPU
def predict_product(num1, num2):
    input_text = f'{preprocess(num1)} * {preprocess(num2)} ='
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    inputs = tokenizer(input_text, return_tensors='pt').to(device)
    model.to(device)

    input_len = inputs['input_ids'].shape[-1]
    generated_ids = inputs['input_ids']
    prediction = ""
    correct_product = ""
    result_html = ""
    valid_input = True

    try:
        num1_int = int(num1)
        num2_int = int(num2)
        correct_product = str(num1_int * num2_int)
    except ValueError:
        valid_input = False

    for _ in range(40):  # Hard cap on generated tokens; raise it to allow longer products
        outputs = model.generate(generated_ids, max_new_tokens=1, do_sample=False)
        next_token = outputs[:, -1:]
        generated_ids = torch.cat((generated_ids, next_token), dim=-1)
        if tokenizer.eos_token_id is not None and next_token.item() == tokenizer.eos_token_id:
            break  # stop early once the model signals the end of the product
        # Decode only the generated continuation (not the prompt) before un-reversing
        output_text = tokenizer.decode(generated_ids[0, input_len:], skip_special_tokens=True)
        prediction = postprocess(output_text)
        
        # Highlight each predicted digit: green if it matches the true product at that position
        result_html = "<div style='margin-bottom: 10px;'>Correct Result: " + " ".join(correct_product) + "</div><div>"
        for i, pred_digit in enumerate(prediction):
            color = "green" if i < len(correct_product) and pred_digit == correct_product[i] else "red"
            result_html += f"<span style='color: {color};'>{pred_digit}</span>"
        result_html += "</div>"

        yield result_html, ""

    if valid_input:
        is_correct = prediction == correct_product
        result_message = "Correct!" if is_correct else f"Incorrect! The correct product is {correct_product}."
    else:
        result_message = "Invalid input. Could not evaluate correctness."

    yield result_html, result_message
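
# Illustrative local check (hypothetical; bypasses the Gradio UI). predict_product
# is a generator, so the final yield carries the verdict message:
#   *_, (html, message) = predict_product('12', '34')
#   print(message)  # 'Correct!' if the model predicts 408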

demo = gr.Interface(
    fn=predict_product,
    inputs=[
        gr.Textbox(label='First Number (up to 12 digits)', value='12345'),
        gr.Textbox(label='Second Number (up to 12 digits)', value='67890'),
    ],
    outputs=[
        gr.HTML(label='Predicted Product with Matching Digits Highlighted'),
        gr.HTML(label='Result Message')
    ],
    title='GPT-2 Direct Multiplication Calculator (Without Using Chain-of-Thought)',
    description='This demo uses GPT-2 to directly predict the product of two numbers without emitting any intermediate reasoning steps. The model has been fine-tuned to internalize chain-of-thought reasoning within its hidden states, following our stepwise internalization approach detailed in the paper linked at the bottom of this page.',
    article="""
    - [Paper: From Explicit CoT to Implicit CoT: Learning to Internalize CoT Step by Step](https://arxiv.org/pdf/2405.14838)
    - [Code Repository](https://github.com/da03/Internalize_CoT_Step_by_Step)
    - [Tweet Announcement](https://twitter.com/yuntiandeng/status/1795854740879774036)
    """,
    clear_btn=None,
    submit_btn="Multiply!",
    live=True
)

demo.launch()
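
# Note: older Gradio (3.x) releases require the queue to be enabled for streaming
# generator outputs, e.g. demo.queue().launch(); Gradio 4.x enables it by default.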