Spaces:
Running
on
Zero
Running
on
Zero
import logging
import os

import gradio as gr
import openai
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# Load the Llama-2 chat model once at import time so every request reuses it.
device = "cuda" if torch.cuda.is_available() else "cpu"
LLAMA_MODEL_ID = "meta-llama/Llama-2-7b-chat-hf"
llama_tokenizer = AutoTokenizer.from_pretrained(LLAMA_MODEL_ID)
llama_model = AutoModelForCausalLM.from_pretrained(
    LLAMA_MODEL_ID,
    device_map="auto",
    # Half precision roughly halves GPU memory for the 7B weights compared
    # with the fp32 default; stay in fp32 on CPU, where fp16 is poorly supported.
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
)

# OpenAI API key: read it from the environment (e.g. a Space secret) rather
# than committing a credential to source. If unset, OpenAI calls will fail
# with an authentication error.
openai.api_key = os.environ.get("OPENAI_API_KEY")
# Function to query Llama
def query_llama(prompt, max_new_tokens=150, max_prompt_tokens=128):
    """Generate a continuation of *prompt* with the local Llama-2 chat model.

    Args:
        prompt: Text sent to the model; truncated to ``max_prompt_tokens``
            tokens before generation.
        max_new_tokens: Upper bound on tokens generated after the prompt.
        max_prompt_tokens: Maximum number of prompt tokens kept.

    Returns:
        Only the newly generated text (the echoed prompt is removed), with
        surrounding whitespace stripped.
    """
    # With device_map="auto" the weights may not live on the module-level
    # `device` guess; the model's own device is where inputs must go.
    inputs = llama_tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=max_prompt_tokens,
    ).to(llama_model.device)
    # max_new_tokens (not max_length) so a long prompt cannot consume the
    # whole generation budget; **inputs also forwards the attention mask.
    outputs = llama_model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        pad_token_id=llama_tokenizer.eos_token_id,
    )
    # Decoder-only generate() returns prompt + continuation; keep the new part.
    prompt_len = inputs["input_ids"].shape[-1]
    new_tokens = outputs[0][prompt_len:]
    return llama_tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
# Function to query GPT
def query_gpt(prompt, model="gpt-3.5-turbo-instruct", max_tokens=150):
    """Return an OpenAI Completions-endpoint answer for *prompt*.

    ``text-davinci-003`` has been retired by OpenAI; ``gpt-3.5-turbo-instruct``
    is its documented replacement on the same Completions endpoint.

    Args:
        prompt: Text to complete.
        model: OpenAI completions model name.
        max_tokens: Maximum number of tokens to generate.

    Returns:
        The completion text with surrounding whitespace stripped.
    """
    if hasattr(openai, "OpenAI"):
        # openai>=1.0 removed `openai.Completion`; the module-level client
        # (configured through `openai.api_key`) exposes `openai.completions`.
        response = openai.completions.create(
            model=model, prompt=prompt, max_tokens=max_tokens
        )
        return response.choices[0].text.strip()
    # Legacy openai<1.0 SDK returns a dict-like response.
    response = openai.Completion.create(
        model=model, prompt=prompt, max_tokens=max_tokens
    )
    return response["choices"][0]["text"].strip()
# Function to compare models
def compare_models(prompt, models):
    """Send *prompt* to every selected model and collect their answers.

    Args:
        prompt: User prompt forwarded unchanged to each model.
        models: Iterable of selected model labels ("Llama", "GPT"); ``None``
            or empty means nothing is selected. Unknown labels are ignored.

    Returns:
        Dict mapping each selected label to its response text, or to an
        ``"Error: ..."`` string if that model's call raised.
    """
    selected = set(models or ())
    responses = {}
    if "Llama" in selected:
        responses["Llama"] = _run_model("Llama", query_llama, prompt)
    if "GPT" in selected:
        responses["GPT"] = _run_model("GPT", query_gpt, prompt)
    return responses


def _run_model(name, query_fn, prompt):
    """Call one backend, turning a failure into a reportable error string."""
    try:
        return query_fn(prompt)
    except Exception as exc:  # UI boundary: one backend failing must not hide the other
        logging.exception("%s query failed", name)
        return f"Error: {exc}"
# Gradio Interface
def gradio_app():
    """Build the Gradio UI that sends one prompt to the selected models.

    Returns:
        The ``gr.Blocks`` app, not yet launched.
    """
    with gr.Blocks() as app:
        gr.Markdown("# AI Model Comparison Tool")
        with gr.Row():
            prompt_input = gr.Textbox(label="Enter your prompt", placeholder="Ask something...")
        with gr.Row():
            model_selector = gr.CheckboxGroup(
                ["Llama", "GPT"],
                label="Select Models to Compare",
                value=["Llama", "GPT"],
            )
        with gr.Row():
            # compare_models returns a dict, so a JSON view shows label -> answer.
            output_boxes = gr.JSON(label="Model Responses")
        with gr.Row():
            compare_button = gr.Button("Compare Models")
        compare_button.click(
            compare_models,
            inputs=[prompt_input, model_selector],
            outputs=[output_boxes],
        )
    return app
if __name__ == "__main__":
    # Script entry point: build the UI, then start the Gradio server.
    demo = gradio_app()
    demo.launch()