whyumesh commited on
Commit
5eb74ff
·
verified ·
1 Parent(s): d936529

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +74 -0
app.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer
4
+ import spaces
5
+
6
+ # Load model and tokenizer
7
+ model_name = "Qwen/Qwen2.5-Coder-1.5B-Instruct"
8
+
9
+ def load_model():
10
+ model = AutoModelForCausalLM.from_pretrained(
11
+ model_name,
12
+ torch_dtype=torch.float16,
13
+ device_map="auto"
14
+ )
15
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
16
+ return model, tokenizer
17
+
18
+ model, tokenizer = load_model()
19
+
20
+ @spaces.GPU(duration=60) # Adjust duration based on your needs
21
+ def fix_code(input_code):
22
+ # Prepare the prompt
23
+ messages = [
24
+ {"role": "system", "content": "You are a helpful coding assistant. Please analyze the following code, identify any errors, and provide the corrected version."},
25
+ {"role": "user", "content": f"Please fix this code:\n\n{input_code}"}
26
+ ]
27
+
28
+ # Apply chat template
29
+ text = tokenizer.apply_chat_template(
30
+ messages,
31
+ tokenize=False,
32
+ add_generation_prompt=True
33
+ )
34
+
35
+ # Tokenize and generate
36
+ model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
37
+ generated_ids = model.generate(
38
+ **model_inputs,
39
+ max_new_tokens=1024,
40
+ temperature=0.7,
41
+ top_p=0.95,
42
+ )
43
+
44
+ # Decode only the new tokens
45
+ generated_ids = [
46
+ output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
47
+ ]
48
+ response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
49
+
50
+ return response
51
+
52
+ # Create Gradio interface
53
+ iface = gr.Interface(
54
+ fn=fix_code,
55
+ inputs=gr.Code(
56
+ label="Input Code",
57
+ language="python",
58
+ lines=10
59
+ ),
60
+ outputs=gr.Code(
61
+ label="Corrected Code",
62
+ language="python",
63
+ lines=10
64
+ ),
65
+ title="Code Correction Tool",
66
+ description="Enter your code with errors, and the AI will attempt to fix it.",
67
+ examples=[
68
+ ["def fibonacci(n):\n if n = 0:\n return 0\n elif n == 1\n return 1\n else:\n return fibonacci(n-1) + fibonacci(n-2)"],
69
+ ["for i in range(10)\n print(i"]
70
+ ]
71
+ )
72
+
73
+ if __name__ == "__main__":
74
+ iface.launch()