import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from accelerate import Accelerator
# Check if GPU is available for better performance
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Initialize the Accelerator (instantiated for potential distributed/optimized setups;
# it is not otherwise used in this script)
accelerator = Accelerator()
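# A minimal sketch of how the Accelerator could actually be applied, assuming you
# want it to handle device placement (hypothetical; the code below instead moves
# each model to `device` manually):
#   model = accelerator.prepare(model)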
# Load models and tokenizers with FP16 for speed optimization if GPU is available
model_dirs = [
    "Poonawala/gpt2",
    "Poonawala/MiriFur",
    "Poonawala/Llama-3.2-1B",
]
models = {}
tokenizers = {}
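# Loaded checkpoints are cached here, keyed by the repo's short name
# (e.g. "gpt2", "MiriFur", "Llama-3.2-1B"); these keys also populate
# the model selector in the Gradio UI below.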
def load_model(model_dir):
    model = AutoModelForCausalLM.from_pretrained(
        model_dir,
        torch_dtype=torch.float16 if device.type == "cuda" else torch.float32,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Move model to GPU/CPU as per availability
    model = model.to(device)
    return model, tokenizer
# Load all models
for model_dir in model_dirs:
    model_name = model_dir.split("/")[-1]
    try:
        model, tokenizer = load_model(model_dir)
        models[model_name] = model
        tokenizers[model_name] = tokenizer

        # Batch warm-up inference to reduce initial response time
        dummy_inputs = ["Hello", "What is a recipe?", "Explain cooking basics"]
        for dummy_input in dummy_inputs:
            input_ids = tokenizer.encode(dummy_input, return_tensors='pt').to(device)
            with torch.no_grad():
                model.generate(input_ids, max_new_tokens=1)

        print(f"Loaded model and tokenizer from {model_dir}.")
    except Exception as e:
        print(f"Failed to load model from {model_dir}: {e}")
        continue
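# Note: if every checkpoint fails to load, `models` stays empty, the model
# selector below has no choices, and get_response() falls through to
# "Model not loaded correctly." for any request.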
def get_response(prompt, model_name, user_type):
    if model_name not in models:
        return "Model not loaded correctly."

    model = models[model_name]
    tokenizer = tokenizers[model_name]

    # Define different prompt templates based on user type
    user_type_templates = {
        "Expert": f"As an Expert, {prompt}\nAnswer:",
        "Intermediate": f"As an Intermediate, {prompt}\nAnswer:",
        "Beginner": f"Explain in simple terms: {prompt}\nAnswer:",
        "Professional": f"As a Professional, {prompt}\nAnswer:",
    }

    # Get the appropriate prompt based on user type
    prompt_template = user_type_templates.get(user_type, f"{prompt}\nAnswer:")

    encoding = tokenizer(
        prompt_template,
        return_tensors='pt',
        padding=True,
        truncation=True,
        max_length=500  # Increased length for larger inputs
    ).to(device)

    max_new_tokens = 200  # Increased to allow full-length answers
    with torch.no_grad():
        output = model.generate(
            input_ids=encoding['input_ids'],
            attention_mask=encoding['attention_mask'],
            max_new_tokens=max_new_tokens,  # Higher value for longer answers
            num_beams=3,                    # Beam search for better quality answers
            do_sample=True,                 # Sampling so temperature/top_p actually take effect
            repetition_penalty=1.2,         # Increased to reduce repetitive text
            temperature=0.9,                # Slightly higher for creative outputs
            top_p=0.9,                      # Including more tokens for diverse generation
            early_stopping=True,
            pad_token_id=tokenizer.pad_token_id
        )

    # Decode only the newly generated tokens so the prompt template is not echoed back
    prompt_length = encoding['input_ids'].shape[1]
    response = tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True)
    return response.strip()
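# Example call (hypothetical prompt; assumes the "gpt2" checkpoint loaded):
#   print(get_response("How do I thicken a sauce?", "gpt2", "Beginner"))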
def process_input(prompt, model_name, user_type):
    if prompt and prompt.strip():
        return get_response(prompt, model_name, user_type)
    else:
        return "Please provide a prompt."
# Gradio Interface with Modern Design
with gr.Blocks(css="""
    body {
        background-color: #faf3e0;  /* Beige for a warm food-related theme */
        font-family: Arial, sans-serif;
    }
    .title {
        font-size: 2.5rem;
        font-weight: bold;
        color: #ff7f50;  /* Coral color for a food-inspired look */
        text-align: center;
        margin-bottom: 1rem;
    }
    .container {
        max-width: 900px;
        margin: auto;
        padding: 2rem;
        background-color: #ffffff;
        border-radius: 10px;
        box-shadow: 0 4px 15px rgba(0, 0, 0, 0.1);
    }
    .button {
        background-color: #ff7f50;  /* Coral color for buttons */
        color: white;
        padding: 0.8rem 1.5rem;
        font-size: 1rem;
        border: none;
        border-radius: 5px;
        cursor: pointer;
    }
    .button:hover {
        background-color: #ffa07a;  /* Light salmon for hover effect */
    }
""") as demo:
    gr.Markdown("<div class='title'>Cookspert: Your Cooking Assistant</div>")

    user_types = ["Expert", "Intermediate", "Beginner", "Professional"]

    with gr.Tabs():
        with gr.TabItem("Ask a Cooking Question"):
            with gr.Row():
                with gr.Column(scale=2):
                    prompt = gr.Textbox(
                        label="Ask about any recipe",
                        placeholder="Ask questions related to cooking here...",
                        lines=2
                    )
                    model_name = gr.Radio(label="Choose Model", choices=list(models.keys()), interactive=True)
                    user_type = gr.Dropdown(label="User Type", choices=user_types, value="Beginner")
                    submit_button = gr.Button("ChefGPT", elem_classes="button")
                    response = gr.Textbox(
                        label="🍽️ Response",
                        placeholder="Your answer will appear here...",
                        lines=10,
                        interactive=False,
                        show_copy_button=True
                    )

            submit_button.click(fn=process_input, inputs=[prompt, model_name, user_type], outputs=response)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", share=True, debug=True)