import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
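# Assumed environment (not pinned by the original file): a reasonably recent Streamlit
# (roughly >= 1.27 for st.chat_message / st.chat_input / st.rerun), transformers, torch,
# and accelerate (required for device_map="auto"). Launch locally with: streamlit run app.py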
class LlamaDemo:
    def __init__(self):
        # Using TinyLlama, which is open source and doesn't require authentication
        self.model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
        # Lazy loading: the heavy download happens on first access, not at construction
        self._model = None
        self._tokenizer = None

    @property
    def model(self):
        if self._model is None:
            # device_map="auto" places the model on GPU when one is available;
            # on CPU-only hardware, float16 may need to be changed to float32
            self._model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                torch_dtype=torch.float16,
                device_map="auto",
                trust_remote_code=True
            )
        return self._model

    @property
    def tokenizer(self):
        if self._tokenizer is None:
            self._tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                trust_remote_code=True
            )
        return self._tokenizer
    def generate_response(self, prompt: str, max_length: int = 512) -> str:
        # Format the prompt according to TinyLlama's chat template
        chat_prompt = (
            f"<|system|>\nYou are a helpful AI assistant.</s>\n"
            f"<|user|>\n{prompt}</s>\n"
            f"<|assistant|>\n"
        )
        inputs = self.tokenizer(chat_prompt, return_tensors="pt").to(self.model.device)

        # Generate the response
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_length,
                num_return_sequences=1,
                temperature=0.7,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id
            )

        # Decode only the newly generated tokens so the prompt is not echoed back
        new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
        response = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
        return response.strip()
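
# Example (illustrative only, not part of the Streamlit app): the class can also be
# used directly from a script or REPL; the model is downloaded on the first call.
#
#     demo = LlamaDemo()
#     print(demo.generate_response("Explain what a tokenizer does.", max_length=128))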
def main():
    st.set_page_config(
        page_title="Open Source Llama Demo",
        page_icon="πŸ¦™",
        layout="wide"
    )
    st.title("πŸ¦™ Open Source Llama Demo")

    # Initialize session state
    if 'llama' not in st.session_state:
        # Construction is cheap; the model itself is loaded lazily on the first prompt
        with st.spinner("Loading model... This might take a few minutes..."):
            st.session_state.llama = LlamaDemo()

    if 'chat_history' not in st.session_state:
        st.session_state.chat_history = []

    # Sidebar with settings and info (rendered before the chat handling so the
    # slider value is available when generating a response)
    with st.sidebar:
        st.header("Settings")
        max_length = st.slider("Maximum response length", 64, 1024, 512)

        st.markdown("---")
        st.markdown("""
        ### About
        This demo uses TinyLlama, an open-source language model that is small but
        still capable, which makes it well suited to demonstrations and testing.

        The model is loaded locally and doesn't require any authentication or API keys.
        """)

        if st.button("Clear Chat History"):
            st.session_state.chat_history = []
            # st.rerun() is the non-deprecated form of st.experimental_rerun()
            st.rerun()

    # Chat interface
    with st.container():
        # Display chat history
        for message in st.session_state.chat_history:
            role = message["role"]
            content = message["content"]
            with st.chat_message(role):
                st.write(content)

        # Input for new message
        if prompt := st.chat_input("What would you like to discuss?"):
            # Add user message to chat history
            st.session_state.chat_history.append({
                "role": "user",
                "content": prompt
            })
            with st.chat_message("user"):
                st.write(prompt)

            # Show assistant response
            with st.chat_message("assistant"):
                message_placeholder = st.empty()
                with st.spinner("Thinking..."):
                    response = st.session_state.llama.generate_response(
                        prompt, max_length=max_length
                    )
                message_placeholder.write(response)

            # Add assistant response to chat history
            st.session_state.chat_history.append({
                "role": "assistant",
                "content": response
            })
if __name__ == "__main__":
    main()