# diabolic6045's picture
# Update app.py
# 1fc00b0 verified
# raw
# history blame
# 3.48 kB
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
# Load the tokenizer and the causal language model from the same checkpoint.
MODEL_ID = "diabolic6045/ELN-Llama-1B-base"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
class ChatBot:
    """Stateful wrapper that lets a causal language model hold a conversation.

    The conversation is stored on the instance as a list of
    ``(user_message, assistant_reply)`` tuples and replayed as a plain-text
    ``User:`` / ``Assistant:`` transcript on every request. Everything that
    shares one instance therefore shares one conversation.
    """

    # Text that opens a user turn in the transcript. The model may keep
    # writing the dialogue on its own, so a reply is cut at this marker.
    _USER_TURN = "\nUser:"

    def __init__(self, model, tokenizer):
        """Keep the model/tokenizer pair and start with an empty history."""
        self.model = model
        self.tokenizer = tokenizer
        self.chat_history = []

    def _build_prompt(self, message):
        """Render the stored turns followed by *message* as the prompt text."""
        past_turns = "".join(
            f"User: {turn[0]}\nAssistant: {turn[1]}\n"
            for turn in self.chat_history
        )
        return f"{past_turns}User: {message}\nAssistant:"

    def generate_response(self, message, temperature=0.7, max_length=512):
        """Generate a reply to *message* and record the turn in the history.

        Args:
            message: The new user message.
            temperature: Sampling temperature handed to ``generate``.
            max_length: Upper bound on the number of *newly generated* tokens.
                The prompt does not count against it, so a long history can
                no longer use up the whole budget before the reply starts.

        Returns:
            A ``(response, chat_history)`` tuple: the reply text and the
            instance's history list, which now ends with the new turn.
        """
        prompt = self._build_prompt(message)
        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True)
        prompt_token_count = inputs["input_ids"].shape[-1]

        with torch.no_grad():
            outputs = self.model.generate(
                input_ids=inputs["input_ids"],
                # The pad token is aliased to EOS below, so padding cannot be
                # inferred from token ids; pass the tokenizer's mask along.
                attention_mask=inputs.get("attention_mask"),
                max_new_tokens=int(max_length),
                temperature=temperature,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
                num_return_sequences=1,
            )

        # Decode only what the model produced. Searching the full transcript
        # for "Assistant:" picks the wrong text when the reply itself
        # contains that marker or the model writes additional turns.
        new_tokens = outputs[0][prompt_token_count:]
        generated = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
        response = generated.split(self._USER_TURN, 1)[0].strip()

        # Update chat history
        self.chat_history.append((message, response))
        return response, self.chat_history

    def clear_history(self):
        """Forget the conversation.

        Returns:
            ``([], "")``: an empty history for the chat display and an empty
            string for the message text box that the UI binds as the second
            output (a list there would be rendered as the text ``[]``).
        """
        self.chat_history = []
        return [], ""
# Single chatbot instance built from the module-level model and tokenizer.
chatbot = ChatBot(model=model, tokenizer=tokenizer)

# Example prompts; each one is wrapped in its own single-item list.
examples = [
    [prompt]
    for prompt in (
        "Hello! How are you today?",
        "Can you explain what machine learning is?",
        "Write a short poem about artificial intelligence.",
    )
]
# Create the Gradio interface
def _respond(user_message, temperature_value, max_length_value):
    """Run one chat turn and return the history for the chat display.

    ``chatbot.generate_response`` returns ``(response, history)``, but the
    submit event drives a single Chatbot output, which needs the list of
    turns rather than the bare reply string.
    """
    _, history = chatbot.generate_response(
        user_message, temperature_value, max_length_value
    )
    return history


def _reset():
    """Clear the stored conversation and blank both bound components.

    Returns an empty list for the chat display and an empty string for the
    message text box (a text box cannot take a list as its value).
    """
    cleared_history, _ = chatbot.clear_history()
    return cleared_history, ""


# The custom CSS hides the page footer.
with gr.Blocks(css="footer {visibility: hidden}") as demo:
    gr.Markdown("# LLaMA Chatbot")
    gr.Markdown("Chat with the ELN-Llama-1B model. Try asking questions or having a conversation!")

    with gr.Row():
        # Wide column: the conversation and the input box.
        with gr.Column(scale=4):
            chatbot_component = gr.Chatbot(
                label="Chat History",
                height=400
            )
            message = gr.Textbox(
                label="Your message",
                placeholder="Type your message here...",
                lines=2
            )
        # Narrow column: generation settings and the reset button.
        with gr.Column(scale=1):
            temperature = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.7,
                step=0.1,
                label="Temperature",
                info="Higher values make output more random"
            )
            max_length = gr.Slider(
                minimum=64,
                maximum=1024,
                value=512,
                step=64,
                label="Max Length",
                info="Maximum length of generated response"
            )
            clear = gr.Button("Clear Conversation")

    gr.Examples(
        examples=examples,
        inputs=message,
        label="Example prompts"
    )

    # Handle interactions: each callback returns exactly one value per
    # component listed in ``outputs``, in the same order.
    message.submit(
        _respond,
        inputs=[message, temperature, max_length],
        outputs=[chatbot_component]
    )
    clear.click(
        _reset,
        outputs=[chatbot_component, message]
    )
# Launch the interface
def main():
    """Start serving ``demo`` with ``share=True``."""
    demo.launch(share=True)


if __name__ == "__main__":
    main()