import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Run inference on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Hugging Face Hub repositories to serve; the UI labels each model by the
# final path segment (e.g. "gpt2", "MiriFur", "Llama-3.2-1B").
model_dirs = [
    "Poonawala/gpt2",
    "Poonawala/MiriFur",
    "Poonawala/Llama-3.2-1B",
]

models = {}
tokenizers = {}


def load_model(model_dir):
    """Load a causal LM and its tokenizer, in fp16 on GPU to save memory."""
    model = AutoModelForCausalLM.from_pretrained(
        model_dir,
        torch_dtype=torch.float16 if device.type == "cuda" else torch.float32,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_dir)

    # GPT-2-style tokenizers ship without a pad token; reuse EOS so padded
    # batches work with generate().
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = model.to(device)
    model.eval()  # inference only; disables dropout
    return model, tokenizer
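
# Note: from_pretrained() downloads checkpoints from the Hugging Face Hub on
# first use and caches them locally, so subsequent launches start faster.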

for model_dir in model_dirs:
    model_name = model_dir.split("/")[-1]
    try:
        model, tokenizer = load_model(model_dir)
        models[model_name] = model
        tokenizers[model_name] = tokenizer

        # Warm-up generations: run a few one-token completions so CUDA
        # kernels are compiled and cached before the first real request.
        dummy_inputs = ["Hello", "What is a recipe?", "Explain cooking basics"]
        for dummy_input in dummy_inputs:
            input_ids = tokenizer.encode(dummy_input, return_tensors="pt").to(device)
            with torch.no_grad():
                model.generate(input_ids, max_new_tokens=1)

        print(f"Loaded model and tokenizer from {model_dir}.")
    except Exception as e:
        print(f"Failed to load model from {model_dir}: {e}")
        continue
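
# If every checkpoint fails to load, `models` stays empty and the model
# picker in the UI below will have no choices to offer.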


def get_response(prompt, model_name, user_type):
    if model_name not in models:
        return "Model not loaded correctly."

    model = models[model_name]
    tokenizer = tokenizers[model_name]

    # Frame the question according to the selected experience level.
    user_type_templates = {
        "Expert": f"As an Expert, {prompt}\nAnswer:",
        "Intermediate": f"As an Intermediate, {prompt}\nAnswer:",
        "Beginner": f"Explain in simple terms: {prompt}\nAnswer:",
        "Professional": f"As a Professional, {prompt}\nAnswer:",
    }
    prompt_template = user_type_templates.get(user_type, f"{prompt}\nAnswer:")

    encoding = tokenizer(
        prompt_template,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=500,
    ).to(device)

    with torch.no_grad():
        output = model.generate(
            input_ids=encoding["input_ids"],
            attention_mask=encoding["attention_mask"],
            max_new_tokens=200,
            num_beams=3,
            do_sample=True,  # temperature/top_p only take effect when sampling
            repetition_penalty=1.2,
            temperature=0.9,
            top_p=0.9,
            early_stopping=True,
            pad_token_id=tokenizer.pad_token_id,
        )

    # Decode only the newly generated tokens so the prompt is not echoed
    # back in the answer.
    new_tokens = output[0][encoding["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
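
# Example (assuming the "gpt2" checkpoint loaded successfully):
#   get_response("How do I boil an egg?", "gpt2", "Beginner")
# returns the generated answer as a plain string.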


def process_input(prompt, model_name, user_type):
    # Guard against empty or whitespace-only prompts before generating.
    if prompt and prompt.strip():
        return get_response(prompt, model_name, user_type)
    return "Please provide a prompt."


with gr.Blocks(css="""
body {
    background-color: #faf3e0; /* Beige for a warm, food-related theme */
    font-family: Arial, sans-serif;
}
.title {
    font-size: 2.5rem;
    font-weight: bold;
    color: #ff7f50; /* Coral for a food-inspired look */
    text-align: center;
    margin-bottom: 1rem;
}
.container {
    max-width: 900px;
    margin: auto;
    padding: 2rem;
    background-color: #ffffff;
    border-radius: 10px;
    box-shadow: 0 4px 15px rgba(0, 0, 0, 0.1);
}
.button {
    background-color: #ff7f50; /* Coral buttons */
    color: white;
    padding: 0.8rem 1.5rem;
    font-size: 1rem;
    border: none;
    border-radius: 5px;
    cursor: pointer;
}
.button:hover {
    background-color: #ffa07a; /* Light salmon on hover */
}
""") as demo:

    gr.Markdown("<div class='title'>Cookspert: Your Cooking Assistant</div>")

    user_types = ["Expert", "Intermediate", "Beginner", "Professional"]

    with gr.Tabs():
        with gr.TabItem("Ask a Cooking Question"):
            with gr.Row():
                with gr.Column(scale=2):
                    prompt = gr.Textbox(
                        label="Ask about any recipe",
                        placeholder="Ask a question related to cooking here...",
                        lines=2,
                    )
                    model_name = gr.Radio(
                        label="Choose Model",
                        choices=list(models.keys()),
                        interactive=True,
                    )
                    user_type = gr.Dropdown(label="User Type", choices=user_types, value="Beginner")
                    submit_button = gr.Button("ChefGPT", elem_classes="button")

                    response = gr.Textbox(
                        label="🍽️ Response",
                        placeholder="Your answer will appear here...",
                        lines=10,
                        interactive=False,
                        show_copy_button=True,
                    )

    # Wire the button to the handler; must run inside the Blocks context.
    submit_button.click(fn=process_input, inputs=[prompt, model_name, user_type], outputs=response)
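
# Headless smoke test (hypothetical prompt), bypassing the UI:
#   print(process_input("What is blanching?", "gpt2", "Beginner"))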


if __name__ == "__main__":
    # share=True also exposes a temporary public Gradio link; debug=True
    # prints full tracebacks to the console.
    demo.launch(server_name="0.0.0.0", share=True, debug=True)