Update app.py
Browse files
app.py
CHANGED
@@ -3,7 +3,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
3 |
import torch
|
4 |
from accelerate import Accelerator
|
5 |
|
6 |
-
# Check if GPU is available for better
|
7 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
8 |
print(f"Using device: {device}")
|
9 |
|
@@ -12,9 +12,9 @@ accelerator = Accelerator()
|
|
12 |
|
13 |
# Load models and tokenizers with FP16 for speed optimization if GPU is available
|
14 |
model_dirs = [
|
15 |
-
"
|
16 |
-
"
|
17 |
-
"
|
18 |
]
|
19 |
|
20 |
models = {}
|
@@ -23,10 +23,10 @@ tokenizers = {}
|
|
23 |
def load_model(model_dir):
|
24 |
model = AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype=torch.float16 if device.type == "cuda" else torch.float32)
|
25 |
tokenizer = AutoTokenizer.from_pretrained(model_dir)
|
26 |
-
|
27 |
if tokenizer.pad_token is None:
|
28 |
tokenizer.pad_token = tokenizer.eos_token
|
29 |
-
|
30 |
# Move model to GPU/CPU as per availability
|
31 |
model = model.to(device)
|
32 |
return model, tokenizer
|
@@ -54,16 +54,16 @@ for model_dir in model_dirs:
|
|
54 |
def get_response(prompt, model_name, user_type):
|
55 |
if model_name not in models:
|
56 |
return "Model not loaded correctly."
|
57 |
-
|
58 |
model = models[model_name]
|
59 |
tokenizer = tokenizers[model_name]
|
60 |
|
61 |
# Define different prompt templates based on user type
|
62 |
user_type_templates = {
|
63 |
-
"
|
64 |
-
"
|
65 |
"Beginner": f"Explain in simple terms: {prompt}\nAnswer:",
|
66 |
-
"
|
67 |
}
|
68 |
|
69 |
# Get the appropriate prompt based on user type
|
@@ -74,20 +74,21 @@ def get_response(prompt, model_name, user_type):
|
|
74 |
return_tensors='pt',
|
75 |
padding=True,
|
76 |
truncation=True,
|
77 |
-
max_length=
|
78 |
).to(device)
|
79 |
|
80 |
-
|
|
|
81 |
|
82 |
with torch.no_grad():
|
83 |
output = model.generate(
|
84 |
input_ids=encoding['input_ids'],
|
85 |
attention_mask=encoding['attention_mask'],
|
86 |
-
max_new_tokens=max_new_tokens,
|
87 |
-
num_beams=
|
88 |
-
repetition_penalty=1.
|
89 |
-
temperature=0.
|
90 |
-
top_p=0.
|
91 |
early_stopping=True,
|
92 |
pad_token_id=tokenizer.pad_token_id
|
93 |
)
|
@@ -99,65 +100,106 @@ def process_input(prompt, model_name, user_type):
|
|
99 |
if prompt and prompt.strip():
|
100 |
return get_response(prompt, model_name, user_type)
|
101 |
else:
|
102 |
-
return "Please provide a prompt."
|
103 |
|
104 |
# Gradio Interface with Modern Design
|
105 |
-
with gr.Blocks(css="""
|
106 |
body {
|
107 |
-
background-color: #
|
108 |
-
font-family: 'Arial, sans-serif
|
109 |
}
|
|
|
110 |
.title {
|
111 |
-
font-size: 2.
|
112 |
-
font-weight:
|
113 |
-
color: #
|
114 |
text-align: center;
|
115 |
-
margin-bottom:
|
116 |
}
|
|
|
117 |
.container {
|
118 |
-
max-width:
|
119 |
margin: auto;
|
120 |
padding: 2rem;
|
121 |
background-color: #ffffff;
|
122 |
border-radius: 10px;
|
123 |
-
box-shadow: 0
|
124 |
}
|
|
|
125 |
.button {
|
126 |
-
background-color: #
|
127 |
color: white;
|
128 |
-
padding: 0.8rem 1.
|
129 |
-
font-size: 1rem;
|
130 |
border: none;
|
131 |
-
border-radius:
|
132 |
cursor: pointer;
|
|
|
|
|
|
|
133 |
}
|
|
|
134 |
.button:hover {
|
135 |
-
background-color: #
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
136 |
}
|
137 |
""") as demo:
|
138 |
|
139 |
-
gr.Markdown("<div class='title'>Cookspert: Your
|
|
|
|
|
|
|
|
|
|
|
|
|
140 |
|
141 |
-
|
|
|
142 |
|
143 |
-
|
144 |
-
|
145 |
-
with gr.Row():
|
146 |
-
with gr.Column(scale=2):
|
147 |
-
prompt = gr.Textbox(label="Ask about any recipe", placeholder="Ask question related to cooking here...", lines=2)
|
148 |
-
model_name = gr.Radio(label="Choose Model", choices=list(models.keys()), interactive=True)
|
149 |
-
user_type = gr.Dropdown(label="User Type", choices=user_types, value="Beginner")
|
150 |
-
submit_button = gr.Button("ChefGPT", elem_classes="button")
|
151 |
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
|
|
|
|
|
|
|
|
|
|
159 |
|
160 |
-
|
161 |
|
162 |
if __name__ == "__main__":
|
163 |
demo.launch(server_name="0.0.0.0", share=True, debug=True)
|
|
|
3 |
import torch
|
4 |
from accelerate import Accelerator
|
5 |
|
6 |
+
# Check if GPU is available for better performance
|
7 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
8 |
print(f"Using device: {device}")
|
9 |
|
|
|
12 |
|
13 |
# Load models and tokenizers with FP16 for speed optimization if GPU is available
|
14 |
model_dirs = [
|
15 |
+
"muhammadAhmed22/fine_tuned_gpt2",
|
16 |
+
"muhammadAhmed22/MiriFurgpt2-recipes",
|
17 |
+
"muhammadAhmed22/auhide-chef-gpt-en"
|
18 |
]
|
19 |
|
20 |
models = {}
|
|
|
23 |
def load_model(model_dir):
|
24 |
model = AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype=torch.float16 if device.type == "cuda" else torch.float32)
|
25 |
tokenizer = AutoTokenizer.from_pretrained(model_dir)
|
26 |
+
|
27 |
if tokenizer.pad_token is None:
|
28 |
tokenizer.pad_token = tokenizer.eos_token
|
29 |
+
|
30 |
# Move model to GPU/CPU as per availability
|
31 |
model = model.to(device)
|
32 |
return model, tokenizer
|
|
|
54 |
def get_response(prompt, model_name, user_type):
|
55 |
if model_name not in models:
|
56 |
return "Model not loaded correctly."
|
57 |
+
|
58 |
model = models[model_name]
|
59 |
tokenizer = tokenizers[model_name]
|
60 |
|
61 |
# Define different prompt templates based on user type
|
62 |
user_type_templates = {
|
63 |
+
"Professional Chef": f"As a professional chef, {prompt}\nAnswer:",
|
64 |
+
"Home Cook": f"As a home cook, {prompt}\nAnswer:",
|
65 |
"Beginner": f"Explain in simple terms: {prompt}\nAnswer:",
|
66 |
+
"Food Enthusiast": f"As a food enthusiast, {prompt}\nAnswer:"
|
67 |
}
|
68 |
|
69 |
# Get the appropriate prompt based on user type
|
|
|
74 |
return_tensors='pt',
|
75 |
padding=True,
|
76 |
truncation=True,
|
77 |
+
max_length=300 # Increased length for larger inputs
|
78 |
).to(device)
|
79 |
|
80 |
+
# Reduce max_new_tokens for faster response time
|
81 |
+
max_new_tokens = 50 # Reduced to speed up response time
|
82 |
|
83 |
with torch.no_grad():
|
84 |
output = model.generate(
|
85 |
input_ids=encoding['input_ids'],
|
86 |
attention_mask=encoding['attention_mask'],
|
87 |
+
max_new_tokens=max_new_tokens,
|
88 |
+
num_beams=1, # Using greedy decoding (faster)
|
89 |
+
repetition_penalty=1.1,
|
90 |
+
temperature=0.7, # Slightly reduced for better performance
|
91 |
+
top_p=0.85, # Reduced top_p for faster results
|
92 |
early_stopping=True,
|
93 |
pad_token_id=tokenizer.pad_token_id
|
94 |
)
|
|
|
100 |
if prompt and prompt.strip():
|
101 |
return get_response(prompt, model_name, user_type)
|
102 |
else:
|
103 |
+
return "Please provide a valid prompt."
|
104 |
|
105 |
# Gradio Interface with Modern Design
|
106 |
+
with gr.Blocks(css="""
|
107 |
body {
|
108 |
+
background-color: #f8f8f8;
|
109 |
+
font-family: 'Helvetica Neue', Arial, sans-serif;
|
110 |
}
|
111 |
+
|
112 |
.title {
|
113 |
+
font-size: 2.6rem;
|
114 |
+
font-weight: 700;
|
115 |
+
color: #ff6347;
|
116 |
text-align: center;
|
117 |
+
margin-bottom: 1.5rem;
|
118 |
}
|
119 |
+
|
120 |
.container {
|
121 |
+
max-width: 800px;
|
122 |
margin: auto;
|
123 |
padding: 2rem;
|
124 |
background-color: #ffffff;
|
125 |
border-radius: 10px;
|
126 |
+
box-shadow: 0 12px 24px rgba(0, 0, 0, 0.1);
|
127 |
}
|
128 |
+
|
129 |
.button {
|
130 |
+
background-color: #ff6347;
|
131 |
color: white;
|
132 |
+
padding: 0.8rem 1.8rem;
|
133 |
+
font-size: 1.1rem;
|
134 |
border: none;
|
135 |
+
border-radius: 8px;
|
136 |
cursor: pointer;
|
137 |
+
transition: background-color 0.3s ease;
|
138 |
+
margin-top: 1.5rem;
|
139 |
+
width: 100%;
|
140 |
}
|
141 |
+
|
142 |
.button:hover {
|
143 |
+
background-color: #ff4500;
|
144 |
+
}
|
145 |
+
|
146 |
+
.gradio-interface .gr-textbox {
|
147 |
+
margin-bottom: 1.5rem;
|
148 |
+
width: 100%;
|
149 |
+
border-radius: 8px;
|
150 |
+
padding: 1rem;
|
151 |
+
border: 1px solid #ddd;
|
152 |
+
font-size: 1rem;
|
153 |
+
background-color: #f9f9f9;
|
154 |
+
color: #333;
|
155 |
+
}
|
156 |
+
|
157 |
+
.gradio-interface .gr-radio, .gradio-interface .gr-dropdown {
|
158 |
+
margin-bottom: 1.5rem;
|
159 |
+
width: 100%;
|
160 |
+
border-radius: 8px;
|
161 |
+
padding: 1rem;
|
162 |
+
border: 1px solid #ddd;
|
163 |
+
background-color: #f9f9f9;
|
164 |
+
font-size: 1rem;
|
165 |
+
color: #333;
|
166 |
+
}
|
167 |
+
|
168 |
+
.gradio-interface .gr-textbox[readonly] {
|
169 |
+
background-color: #f5f5f5;
|
170 |
+
color: #333;
|
171 |
+
font-size: 1rem;
|
172 |
}
|
173 |
""") as demo:
|
174 |
|
175 |
+
gr.Markdown("<div class='title'>Cookspert: Your Personal AI Chef</div>")
|
176 |
+
|
177 |
+
user_types = ["Professional", "Beginner", "Intermediate", "Expert"]
|
178 |
+
|
179 |
+
with gr.Column(scale=1, min_width=350):
|
180 |
+
# Prompt Section
|
181 |
+
prompt = gr.Textbox(label="Enter Your Cooking Question", placeholder="What would you like to ask?", lines=3)
|
182 |
|
183 |
+
# Model Selection Section
|
184 |
+
model_name = gr.Radio(label="Choose Model", choices=list(models.keys()), interactive=True)
|
185 |
|
186 |
+
# User Type Selection
|
187 |
+
user_type = gr.Dropdown(label="Select Your Skill Level", choices=user_types, value="Home Cook")
|
|
|
|
|
|
|
|
|
|
|
|
|
188 |
|
189 |
+
# Submit Button
|
190 |
+
submit_button = gr.Button("chef gpt", elem_classes="button")
|
191 |
+
|
192 |
+
# Response Section
|
193 |
+
response = gr.Textbox(
|
194 |
+
label="Response",
|
195 |
+
placeholder="Your answer will appear here...",
|
196 |
+
lines=10,
|
197 |
+
interactive=False,
|
198 |
+
show_copy_button=True,
|
199 |
+
max_lines=12
|
200 |
+
)
|
201 |
|
202 |
+
submit_button.click(fn=process_input, inputs=[prompt, model_name, user_type], outputs=response)
|
203 |
|
204 |
if __name__ == "__main__":
|
205 |
demo.launch(server_name="0.0.0.0", share=True, debug=True)
|