# Source: Hugging Face Space (status when captured: Sleeping)
import gradio as gr | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
import torch | |
# Load the tokenizer and the causal-LM weights once, at import time.
# Both come from the same checkpoint, so the id is named a single time.
MODEL_ID = "diabolic6045/ELN-Llama-1B-base"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
class ChatBot:
    """Plain-text chat wrapper around a causal language model.

    The conversation is kept in ``chat_history`` as a list of
    ``(user_message, assistant_reply)`` tuples and is replayed as a
    ``User: ... / Assistant: ...`` transcript in front of every new message.
    """

    def __init__(self, model, tokenizer):
        """Store the model/tokenizer pair and start with an empty history."""
        self.model = model
        self.tokenizer = tokenizer
        self.chat_history = []

    def generate_response(self, message, temperature=0.7, max_length=512):
        """Generate a reply to *message* and record the turn in the history.

        Args:
            message: The new user message.
            temperature: Sampling temperature forwarded to ``generate``.
            max_length: Upper bound on the number of *newly generated*
                tokens. The prompt is not counted, so the reply budget does
                not shrink as the conversation grows.

        Returns:
            A ``(response, chat_history)`` tuple: the reply text and the
            updated history list (the same list object held by the instance).
        """
        # Replay previous turns, then open a fresh assistant turn.
        turns = [
            f"User: {user_text}\nAssistant: {bot_text}\n"
            for user_text, bot_text in self.chat_history
        ]
        turns.append(f"User: {message}\nAssistant:")
        conversation = "".join(turns)

        inputs = self.tokenizer(conversation, return_tensors="pt", truncation=True)
        input_ids = inputs["input_ids"]
        prompt_len = input_ids.shape[-1]

        with torch.no_grad():
            outputs = self.model.generate(
                input_ids,
                # Forward the mask explicitly: pad and EOS share an id here,
                # so it cannot be inferred reliably from the ids alone.
                attention_mask=inputs.get("attention_mask"),
                # Sliders may deliver floats; generate() needs an int.
                max_new_tokens=int(max_length),
                temperature=temperature,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
                num_return_sequences=1,
            )

        # Decode only what was generated after the prompt. Splitting the full
        # transcript on "Assistant:" would pick up a later, model-invented
        # turn instead of the reply to this message.
        response = self.tokenizer.decode(
            outputs[0][prompt_len:], skip_special_tokens=True
        )
        # A base model tends to keep writing the dialogue; stop at the first
        # invented turn marker.
        for marker in ("\nUser:", "\nAssistant:"):
            response = response.split(marker, 1)[0]
        response = response.strip()

        self.chat_history.append((message, response))
        return response, self.chat_history

    def clear_history(self):
        """Forget all stored turns and return two empty lists."""
        self.chat_history = []
        return [], []
# One module-level bot instance; the UI callbacks below talk to it.
chatbot = ChatBot(model, tokenizer)

# Starter prompts offered in the UI. Each example row is a one-element list
# because the examples feed a single input component.
_EXAMPLE_PROMPTS = (
    "Hello! How are you today?",
    "Can you explain what machine learning is?",
    "Write a short poem about artificial intelligence.",
)
examples = [[prompt] for prompt in _EXAMPLE_PROMPTS]
# Adapter callbacks between the UI and the module-level `chatbot` instance.
# NOTE: `chatbot` is a single shared object, so every visitor of the app
# reads and appends to the same conversation history.
def _respond(user_message, temp, max_len):
    """Send *user_message* to the bot; return ``(textbox_value, history)``.

    ``ChatBot.generate_response`` returns ``(response, history)``, but the
    chat display needs the history list. The first output is an empty string
    so the input box is cleared after sending.
    """
    if not user_message or not user_message.strip():
        # Nothing to send: leave the conversation untouched.
        return "", chatbot.chat_history
    _, history = chatbot.generate_response(user_message, temp, max_len)
    return "", history


def _reset_conversation():
    """Clear the stored history; return ``(chat_display, textbox_value)``."""
    chatbot.clear_history()
    # The textbox needs a string; an empty list would be rendered as "[]".
    return [], ""


# Create the Gradio interface
with gr.Blocks(css="footer {visibility: hidden}") as demo:
    gr.Markdown("# LLaMA Chatbot")
    gr.Markdown("Chat with the ELN-Llama-1B model. Try asking questions or having a conversation!")

    with gr.Row():
        # Left: conversation view and input box.
        with gr.Column(scale=4):
            chatbot_component = gr.Chatbot(
                label="Chat History",
                height=400,
            )
            message = gr.Textbox(
                label="Your message",
                placeholder="Type your message here...",
                lines=2,
            )

        # Right: sampling controls and the reset button.
        with gr.Column(scale=1):
            temperature = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.7,
                step=0.1,
                label="Temperature",
                info="Higher values make output more random",
            )
            max_length = gr.Slider(
                minimum=64,
                maximum=1024,
                value=512,
                step=64,
                label="Max Length",
                info="Maximum length of generated response",
            )
            clear = gr.Button("Clear Conversation")

    gr.Examples(
        examples=examples,
        inputs=message,
        label="Example prompts",
    )

    # Handle interactions: the number of outputs matches what each callback
    # returns (textbox value + chat history, or chat history + textbox value).
    message.submit(
        _respond,
        inputs=[message, temperature, max_length],
        outputs=[message, chatbot_component],
    )
    clear.click(
        _reset_conversation,
        outputs=[chatbot_component, message],
    )
# Start the web server only when this file is executed as a script,
# requesting a public share link.
if __name__ == "__main__":
    launch_options = {"share": True}
    demo.launch(**launch_options)