File size: 8,455 Bytes
2a01fa1
7f73a1c
2a01fa1
 
 
5d6783e
2a01fa1
896e071
 
 
 
d0b4b02
896e071
74d6bf5
2ccb1ee
5d6783e
58755ce
5d6783e
 
896e071
 
 
 
 
 
 
 
5d6783e
03948e3
5d6783e
aedf8bf
 
 
2ccb1ee
 
03948e3
5289607
aedf8bf
 
 
 
 
 
2ccb1ee
aedf8bf
 
 
 
 
2ccb1ee
 
 
 
 
 
 
 
 
 
 
 
 
 
5289607
 
2ccb1ee
 
aedf8bf
5289607
 
 
 
2ccb1ee
 
 
 
 
 
 
 
 
aedf8bf
2ccb1ee
 
 
 
58755ce
5d6783e
7f73a1c
 
 
 
 
 
 
 
 
5d6783e
7f73a1c
 
 
 
5d6783e
7f73a1c
 
 
 
5d6783e
7f73a1c
 
 
 
5d6783e
 
e7d74c5
7f73a1c
e7353da
98a403f
 
 
 
e7353da
7f73a1c
582c792
2a01fa1
 
582c792
 
2a01fa1
 
582c792
 
e7d74c5
 
 
582c792
2a01fa1
5d6783e
58755ce
5d6783e
 
2a01fa1
e7d74c5
 
5d6783e
 
e7d74c5
5d6783e
 
 
 
 
 
2a01fa1
5d6783e
 
 
 
 
 
 
 
 
 
 
 
58755ce
2a01fa1
846298d
2ccb1ee
582c792
 
aedf8bf
582c792
846298d
 
5d6783e
447b558
b8b21a5
846298d
5d6783e
 
aedf8bf
846298d
 
 
 
 
2a01fa1
 
5d6783e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
import spaces
import re 
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
import torch
import json

# Preamble emitted at the top of every prompt: Mathlib/Aesop imports, an
# unlimited heartbeat budget, and the namespaces opened for the statement.
LEAN4_DEFAULT_HEADER = """\
import Mathlib
import Aesop

set_option maxHeartbeats 0

open BigOperators Real Nat Topology Rat
"""

# Two-line banner text. main() passes its own title string to gr.Interface,
# so this constant is not referenced elsewhere in this file.
title = """🙋🏻‍♂️Welcome to🌟Tonic's🔮Goedel Prover📉
You can build with this endpoint using🔮Goedel-Prover-SFT📉 available here : [Goedel-LM/Goedel-Prover-SFT](https://huggingface.co/Goedel-LM/Goedel-Prover-SFT)."""

def format_prompt(formal_statement, informal_prefix=""):
    """Build the completion prompt sent to the prover model.

    The prompt is an instruction line, a blank line, and then an *unclosed*
    ```lean4 fence containing LEAN4_DEFAULT_HEADER, the informal prefix and the
    formal statement, each separated by a newline. Leaving the fence open lets
    the model continue writing the Lean code.

    Args:
        formal_statement: The statement/goal text placed last in the prompt.
        informal_prefix: Optional text placed between the header and the
            statement (empty by default).

    Returns:
        The assembled prompt string.
    """
    template = (
        "Complete the following Lean 4 code with explanatory comments preceding each line of code:\n\n"
        "```lean4\n"
        "{header}\n"
        "{prefix}\n"
        "{statement}"
    )
    return template.format(
        header=LEAN4_DEFAULT_HEADER,
        prefix=informal_prefix,
        statement=formal_statement,
    )
    
def extract_code(response):
    """Extract the Lean 4 code from the last ```lean4 ... ``` block of *response*.

    Cleaning rules applied to that block:

    * Leading header lines (``import Mathlib``, ``import Aesop``,
      ``set_option ...``, ``open BigOperators ...``) are de-duplicated, blank
      lines between them are dropped, and one blank line separates them from
      the rest of the code. If there are no header lines, no separator is added.
    * Lines containing the prompt instruction ``"Complete the following"`` are
      removed.
    * A line containing ``Goal:`` that is identical to one already kept is
      removed.

    Every other line is kept once, in its original order.

    Args:
        response: Text that may contain fenced ```lean4 blocks, typically the
            prompt concatenated with the model's completion.

    Returns:
        The cleaned code. If *response* has no complete ```lean4 ... ``` block,
        ``response.strip()`` is returned. If processing raises (e.g. *response*
        is not a string), the error is printed and ``"Error processing code"``
        is returned.
    """
    header_markers = ("import Mathlib", "import Aesop", "set_option", "open BigOperators")
    try:
        # Non-greedy match: each opening fence pairs with the next closing fence.
        matches = re.findall(r'```lean4(.*?)```', response, re.DOTALL)
        if not matches:
            # No closed code block (e.g. generation stopped early): return raw text.
            return response.strip()

        headers = []
        body = []
        seen_goals = set()
        in_header = True

        for line in matches[-1].strip().split('\n'):
            if in_header:
                # Blank lines inside the header section are dropped; a single
                # separator is added after the headers below.
                if not line.strip():
                    continue
                if any(marker in line for marker in header_markers):
                    if line not in headers:
                        headers.append(line)
                    continue
                # The first non-blank, non-header line ends the header section.
                in_header = False

            if "Complete the following" in line:
                continue
            if "Goal:" in line:
                if line in seen_goals:
                    continue
                seen_goals.add(line)
            # Every body line (including any proof after ":= by") is kept exactly
            # once; nothing is re-appended afterwards, so the proof is not duplicated.
            body.append(line)

        if headers:
            headers.append("")  # blank separator between headers and code
        return '\n'.join(headers + body)
    except Exception as e:
        # Best-effort helper for the UI: report the problem instead of crashing.
        print(f"Error in extract_code: {str(e)}")
        return "Error processing code"
        

# Example problems shown in the UI (main() credits the introspector/unimath
# dataset for them). Each lists context entries above the "====" separator and
# the target below it.
unimath1 = """Goal:
  X : UU
  Y : UU
  P : UU
  xp : (X → P) → P
  yp : (Y → P) → P
  X0 : X × Y → P
  x : X
  ============================
   (Y → P)"""

unimath2 = """Goal:
    R : ring  M : module R
  ============================
   (islinear (idfun M))"""

unimath3 = """Goal:
    X : UU  i : nat  b : hProptoType (i < S i)  x : Vector X (S i)  r : i = i
  ============================
   (pr1 lastelement = pr1 (i,, b))"""

unimath4 = """Goal:
    X : dcpo  CX : continuous_dcpo_struct X  x : pr1hSet X  y : pr1hSet X
  ============================
   (x ⊑ y ≃ (∀ i : approximating_family CX x, approximating_family CX x i ⊑ y))"""

# Default informal prefix: a Lean block comment inserted before the statement.
additional_info_prompt = "/-Explain using mathematics-/\n"

# One row per example, in the order of main()'s inputs:
# [formal statement, informal prefix, max tokens].
examples = [
    [goal_text, additional_info_prompt, 2500]
    for goal_text in (unimath1, unimath2, unimath3, unimath4)
]

# Load the checkpoint once at import time; every request reuses these objects.
model_name = "Goedel-LM/Goedel-Prover-SFT"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)


def _build_generation_config(checkpoint):
    """Return the checkpoint's GenerationConfig with this app's sampling overrides."""
    config = GenerationConfig.from_pretrained(checkpoint)
    overrides = {
        # Read from the checkpoint's own config *before* eos_token_id is
        # overridden below, so pad uses the original eos id.
        "pad_token_id": config.eos_token_id,
        "bos_token_id": 100000,
        "eos_token_id": 100001,
        "do_sample": True,
        "temperature": 1.0,
        "top_p": 0.95,
    }
    for field_name, field_value in overrides.items():
        setattr(config, field_name, field_value)
    return config


model.generation_config = _build_generation_config(model_name)

@spaces.GPU
def solve_math_problem(question, informal_prefix, max_tokens):
    """Generate a Lean 4 proof attempt for *question* with the loaded prover model.

    Args:
        question: The formal statement / goal text to complete.
        informal_prefix: Text inserted between the Lean header and the statement.
        max_tokens: Maximum number of newly generated tokens. Coerced to ``int``
            because the UI slider may deliver a float.

    Returns:
        A ``(json_text, full_code)`` pair. ``json_text`` is a pretty-printed JSON
        object with ``model_input`` (the prompt), ``model_output`` (only the
        generated continuation) and ``full_code`` (the code extracted from
        prompt + continuation). ``full_code`` is also returned on its own.
    """
    prompt = format_prompt(question, informal_prefix)

    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    # Single unpadded sequence, so every position is attended to.
    attention_mask = torch.ones_like(input_ids)
    prompt_length = input_ids.shape[1]

    outputs = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=int(max_tokens),
        pad_token_id=model.generation_config.pad_token_id,
        temperature=1.0,
        top_p=0.95,
    )

    # For a causal LM, generate() returns the prompt tokens followed by the new
    # ones. Decode only the new tokens. Otherwise prompt + result would contain
    # the prompt twice, and extract_code's fence regex would close the first
    # ```lean4 block at the second prompt's fence, dropping the model's output.
    result = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

    # The prompt ends inside an open ```lean4 fence; the completion closes it.
    full_code = extract_code(prompt + result)

    output_data = {
        "model_input": prompt,
        "model_output": result,
        "full_code": full_code,
    }

    return json.dumps(output_data, indent=2), full_code

def main():
    """Build the Gradio interface around solve_math_problem and launch it."""
    iface = gr.Interface(
        title="🙋🏻‍♂️Welcome to🌟Tonic's🔮Goedel Prover📉",
        description="""You can build with this endpoint using🔮Goedel-Prover-SFT📉 available here : [Goedel-LM/Goedel-Prover-SFT](https://huggingface.co/Goedel-LM/Goedel-Prover-SFT). We're using 🤖[introspector/unimath](https://huggingface.co/datasets/introspector/unimath) for cool examples, check it out below ! The demo is still a work in progress and we're looking forward to build downstream tasks that showcase outstanding mathematical reasoning. Have any ideas ? join us below !
You can also use 🔮Goedel Prover📉 by cloning this space. Simply click here: <a style="display:inline-block" href="https://huggingface.co/spaces/Tonic/Math?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" alt="Duplicate Space"></a>
Join us : 🌟TeamTonic🌟 is always making cool demos! Join our active builder's 🛠️community 👻 [Join us on Discord](https://discord.gg/GWpVpekp) On 🤗Huggingface: [TeamTonic](https://huggingface.co/TeamTonic) & [MultiTransformer](https://huggingface.co/MultiTransformer) Math with [introspector](https://huggingface.co/introspector) On 🌐Github: [Tonic-AI](https://github.com/tonic-ai) & contribute to🌟 [SciTonic](https://github.com/Tonic-AI/scitonic)🤗Big thanks to Yuvi Sharma and all the folks at huggingface for the community grant 🤗
""",
        fn=solve_math_problem,
        # solve_math_problem returns (json_text, full_code), in this order.
        outputs=[
            gr.JSON(label="Full Output"),
            # NOTE(review): presumably Python highlighting is used because the
            # Code component offers no Lean mode — confirm against Gradio docs.
            gr.Code(label="Extracted Lean4 Code", language="python")
        ],
        # Order matches solve_math_problem(question, informal_prefix, max_tokens)
        # and the rows of `examples`.
        inputs=[
            gr.Textbox(label="🤔Enter your Lean4 formal statement", lines=7),
            gr.Textbox(value=additional_info_prompt, label="🪜Optional informal prefix"),
            gr.Slider(minimum=150, maximum=4086, value=2500, label="🪙Max Tokens")
        ],
        examples=examples
    )

    iface.launch()

if __name__ == "__main__":
    main()