import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Use the GPU when available for faster inference
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Model repositories to load; weights are cast to FP16 on GPU for speed
model_dirs = [
    "Poonawala/gpt2",
    "Poonawala/MiriFur",
    "Poonawala/Llama-3.2-1B",
]

models = {}
tokenizers = {}

def load_model(model_dir):
    # FP16 halves memory use and speeds up GPU inference; fall back to FP32 on CPU
    model = AutoModelForCausalLM.from_pretrained(
        model_dir,
        torch_dtype=torch.float16 if device.type == "cuda" else torch.float32,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    # GPT-2- and Llama-style tokenizers ship without a pad token, so reuse EOS
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Move the model to GPU/CPU as available
    model = model.to(device)
    return model, tokenizer

# Load all models
for model_dir in model_dirs:
    model_name = model_dir.split("/")[-1]
    try:
        model, tokenizer = load_model(model_dir)
        models[model_name] = model
        tokenizers[model_name] = tokenizer
        # Warm-up inference to reduce the latency of the first real request
        dummy_inputs = ["Hello", "What is a recipe?", "Explain cooking basics"]
        for dummy_input in dummy_inputs:
            input_ids = tokenizer.encode(dummy_input, return_tensors="pt").to(device)
            with torch.no_grad():
                model.generate(input_ids, max_new_tokens=1)
        print(f"Loaded model and tokenizer from {model_dir}.")
    except Exception as e:
        print(f"Failed to load model from {model_dir}: {e}")
        continue
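
# Note: the one-token warm-up generations above pull one-time startup costs
# (CUDA kernel selection, memory allocation) forward to load time, so the
# first user-facing request does not pay them.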

def get_response(prompt, model_name, user_type):
    if model_name not in models:
        return "The selected model is not loaded. Please choose another model."
    model = models[model_name]
    tokenizer = tokenizers[model_name]
    # Prompt templates tailored to the selected user type
    user_type_templates = {
        "Expert": f"As an Expert, {prompt}\nAnswer:",
        "Intermediate": f"As an Intermediate, {prompt}\nAnswer:",
        "Beginner": f"Explain in simple terms: {prompt}\nAnswer:",
        "Professional": f"As a Professional, {prompt}\nAnswer:",
    }
    # Fall back to a plain prompt for unknown user types
    prompt_template = user_type_templates.get(user_type, f"{prompt}\nAnswer:")
    encoding = tokenizer(
        prompt_template,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=500,  # Generous limit for longer questions
    ).to(device)
    max_new_tokens = 200  # Enough room for full-length answers
    with torch.no_grad():
        output = model.generate(
            input_ids=encoding["input_ids"],
            attention_mask=encoding["attention_mask"],
            max_new_tokens=max_new_tokens,
            num_beams=3,  # Beam search for better-quality answers
            do_sample=True,  # Required for temperature/top_p to take effect
            repetition_penalty=1.2,  # Discourage repetitive text
            temperature=0.9,  # Slightly higher for creative outputs
            top_p=0.9,  # Nucleus sampling for diverse generation
            early_stopping=True,
            pad_token_id=tokenizer.pad_token_id,
        )
    # Decode only the newly generated tokens so the prompt is not echoed back
    new_tokens = output[0][encoding["input_ids"].shape[1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True)
    return response.strip()
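
# Note: with num_beams > 1 and do_sample=True, transformers performs
# beam-sample decoding (sampling within each beam). For fully deterministic
# answers, drop do_sample/temperature/top_p and keep plain beam search.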

def process_input(prompt, model_name, user_type):
    if prompt and prompt.strip():
        return get_response(prompt, model_name, user_type)
    else:
        return "Please provide a prompt."

# Gradio interface with a modern, food-themed design
with gr.Blocks(css="""
body {
    background-color: #faf3e0; /* Beige for a warm, food-related theme */
    font-family: Arial, sans-serif;
}
.title {
    font-size: 2.5rem;
    font-weight: bold;
    color: #ff7f50; /* Coral for a food-inspired look */
    text-align: center;
    margin-bottom: 1rem;
}
.container {
    max-width: 900px;
    margin: auto;
    padding: 2rem;
    background-color: #ffffff;
    border-radius: 10px;
    box-shadow: 0 4px 15px rgba(0, 0, 0, 0.1);
}
.button {
    background-color: #ff7f50; /* Coral buttons */
    color: white;
    padding: 0.8rem 1.5rem;
    font-size: 1rem;
    border: none;
    border-radius: 5px;
    cursor: pointer;
}
.button:hover {
    background-color: #ffa07a; /* Light salmon on hover */
}
""") as demo:
    gr.Markdown("<div class='title'>Cookspert: Your Cooking Assistant</div>")
    user_types = ["Expert", "Intermediate", "Beginner", "Professional"]
    with gr.Tabs():
        with gr.TabItem("Ask a Cooking Question"):
            with gr.Row():
                with gr.Column(scale=2):
                    prompt = gr.Textbox(
                        label="Ask about any recipe",
                        placeholder="Ask a cooking-related question here...",
                        lines=2,
                    )
                    model_name = gr.Radio(
                        label="Choose Model",
                        choices=list(models.keys()),
                        value=next(iter(models), None),  # Pre-select the first loaded model
                        interactive=True,
                    )
                    user_type = gr.Dropdown(label="User Type", choices=user_types, value="Beginner")
                    submit_button = gr.Button("ChefGPT", elem_classes="button")
                    response = gr.Textbox(
                        label="🍽️ Response",
                        placeholder="Your answer will appear here...",
                        lines=10,
                        interactive=False,
                        show_copy_button=True,
                    )
    submit_button.click(fn=process_input, inputs=[prompt, model_name, user_type], outputs=response)
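
# Note: server_name="0.0.0.0" binds to all network interfaces, which is
# typically required inside containers (e.g. a Hugging Face Space);
# share=True additionally requests a public Gradio share link.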

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", share=True, debug=True)