import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from accelerate import Accelerator

# Check if a GPU is available for better performance
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Initialize the Accelerator for optimized inference
accelerator = Accelerator()

# Load models and tokenizers, using FP16 for speed when a GPU is available
model_dirs = [
    "Poonawala/gpt2",
    "Poonawala/MiriFur",
    "Poonawala/Llama-3.2-1B"
]

models = {}
tokenizers = {}

def load_model(model_dir):
    model = AutoModelForCausalLM.from_pretrained(
        model_dir,
        torch_dtype=torch.float16 if device.type == "cuda" else torch.float32
    )
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Move the model to GPU or CPU depending on availability
    model = model.to(device)
    return model, tokenizer

# Load all models
for model_dir in model_dirs:
    model_name = model_dir.split("/")[-1]
    try:
        model, tokenizer = load_model(model_dir)
        models[model_name] = model
        tokenizers[model_name] = tokenizer

        # Warm-up inference on a few dummy prompts to reduce the first response time
        dummy_inputs = ["Hello", "What is a recipe?", "Explain cooking basics"]
        for dummy_input in dummy_inputs:
            input_ids = tokenizer.encode(dummy_input, return_tensors='pt').to(device)
            with torch.no_grad():
                model.generate(input_ids, max_new_tokens=1)

        print(f"Loaded model and tokenizer from {model_dir}.")
    except Exception as e:
        print(f"Failed to load model from {model_dir}: {e}")
        continue

def get_response(prompt, model_name, user_type):
    if model_name not in models:
        return "Model not loaded correctly."

    model = models[model_name]
    tokenizer = tokenizers[model_name]

    # Define different prompt templates based on the user type
    user_type_templates = {
        "Expert": f"As an Expert, {prompt}\nAnswer:",
        "Intermediate": f"As an Intermediate, {prompt}\nAnswer:",
        "Beginner": f"Explain in simple terms: {prompt}\nAnswer:",
        "Professional": f"As a Professional, {prompt}\nAnswer:"
    }

    # Pick the template for the user type, falling back to the raw prompt
    prompt_template = user_type_templates.get(user_type, f"{prompt}\nAnswer:")

    encoding = tokenizer(
        prompt_template,
        return_tensors='pt',
        padding=True,
        truncation=True,
        max_length=500  # Increased length for larger inputs
    ).to(device)

    max_new_tokens = 200  # Increased to allow full-length answers

    with torch.no_grad():
        output = model.generate(
            input_ids=encoding['input_ids'],
            attention_mask=encoding['attention_mask'],
            max_new_tokens=max_new_tokens,  # Higher value for longer answers
            num_beams=3,                    # Beam search for better-quality answers
            repetition_penalty=1.2,         # Increased to reduce repetitive text
            temperature=0.9,                # Slightly higher for creative outputs
            top_p=0.9,                      # Include more tokens for diverse generation
            early_stopping=True,
            pad_token_id=tokenizer.pad_token_id
        )

    response = tokenizer.decode(output[0], skip_special_tokens=True)
    return response.strip()

def process_input(prompt, model_name, user_type):
    if prompt and prompt.strip():
        return get_response(prompt, model_name, user_type)
    else:
        return "Please provide a prompt."
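# Optional local sanity check — a minimal sketch with a hypothetical prompt, assuming
# at least one model loaded successfully; uncomment to exercise get_response without the UI:
# if models:
#     first_model = next(iter(models))
#     print(get_response("How do I thicken a sauce?", first_model, "Beginner"))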
# Gradio interface with a modern design
with gr.Blocks(css="""
body {
    background-color: #faf3e0;  /* Beige for a warm, food-related theme */
    font-family: Arial, sans-serif;
}
.title {
    font-size: 2.5rem;
    font-weight: bold;
    color: #ff7f50;  /* Coral for a food-inspired look */
    text-align: center;
    margin-bottom: 1rem;
}
.container {
    max-width: 900px;
    margin: auto;
    padding: 2rem;
    background-color: #ffffff;
    border-radius: 10px;
    box-shadow: 0 4px 15px rgba(0, 0, 0, 0.1);
}
.button {
    background-color: #ff7f50;  /* Coral for buttons */
    color: white;
    padding: 0.8rem 1.5rem;
    font-size: 1rem;
    border: none;
    border-radius: 5px;
    cursor: pointer;
}
.button:hover {
    background-color: #ffa07a;  /* Light salmon hover effect */
}
""") as demo:
    gr.Markdown('<div class="title">Cookspert: Your Cooking Assistant</div>')

    user_types = ["Expert", "Intermediate", "Beginner", "Professional"]

    with gr.Tabs():
        with gr.TabItem("Ask a Cooking Question"):
            with gr.Row():
                with gr.Column(scale=2):
                    prompt = gr.Textbox(
                        label="Ask about any recipe",
                        placeholder="Ask a question related to cooking here...",
                        lines=2
                    )
                    model_name = gr.Radio(
                        label="Choose Model",
                        choices=list(models.keys()),
                        interactive=True
                    )
                    user_type = gr.Dropdown(
                        label="User Type",
                        choices=user_types,
                        value="Beginner"
                    )
                    submit_button = gr.Button("ChefGPT", elem_classes="button")
                    response = gr.Textbox(
                        label="🍽️ Response",
                        placeholder="Your answer will appear here...",
                        lines=10,
                        interactive=False,
                        show_copy_button=True
                    )

            submit_button.click(
                fn=process_input,
                inputs=[prompt, model_name, user_type],
                outputs=response
            )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", share=True, debug=True)