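"""Streamlit chat app for amd/AMD-OLMo-1B-SFT.

Two tabs: a model-information page and a chat interface that keeps the
conversation history in Streamlit session state.
"""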
import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Initialize session state for chat history
if 'messages' not in st.session_state:
    st.session_state.messages = []
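# Each history entry has the shape {"role": "user" | "assistant", "content": str}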
@st.cache_resource
def load_model():
    # Cache the model so Streamlit does not reload it on every rerun
    tokenizer = AutoTokenizer.from_pretrained("amd/AMD-OLMo-1B-SFT")
    model = AutoModelForCausalLM.from_pretrained("amd/AMD-OLMo-1B-SFT")
    if torch.cuda.is_available():
        model = model.to("cuda")
    return model, tokenizer
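
# Optional variant (an assumption, not part of the original app): loading the
# weights in bfloat16 roughly halves GPU memory for the 1.2B model, e.g.
#   model = AutoModelForCausalLM.from_pretrained(
#       "amd/AMD-OLMo-1B-SFT", torch_dtype=torch.bfloat16)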
def generate_response(prompt, model, tokenizer, history):
    # OLMo's tokenizer has no dedicated BOS token; the EOS token string is
    # used to start the sequence
    bos = tokenizer.eos_token
    # Format prior turns with the <|user|>/<|assistant|> chat template
    conversation = ""
    for msg in history:
        if msg["role"] == "user":
            conversation += f"<|user|>\n{msg['content']}\n"
        else:
            conversation += f"<|assistant|>\n{msg['content']}\n"
    template = bos + conversation + f"<|user|>\n{prompt}\n<|assistant|>\n"
    inputs = tokenizer([template], return_tensors='pt', return_token_type_ids=False)
    if torch.cuda.is_available():
        inputs = inputs.to("cuda")
    outputs = model.generate(
        **inputs,
        max_new_tokens=1000,
        do_sample=True,
        top_k=50,
        top_p=0.95,
        temperature=0.7
    )
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # The template markers are plain text rather than special tokens, so they
    # survive decoding; keep only the text after the last <|assistant|> marker
    response = response.split("<|assistant|>\n")[-1].strip()
    return response
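
# For reference, a single-turn prompt assembled above looks like
# (assuming "<|endoftext|>" is the tokenizer's EOS token):
#   <|endoftext|><|user|>\nWhat is a GPU?\n<|assistant|>\n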
def main():
    st.set_page_config(page_title="AMD-OLMo Chatbot", layout="wide")

    # Custom CSS for tabs and chat bubbles (kept flush-left so Markdown does
    # not treat the indented HTML as a code block)
    st.markdown("""
<style>
.stTab {
    font-size: 20px;
}
.model-info {
    background-color: #f0f2f6;
    padding: 20px;
    border-radius: 10px;
}
.chat-message {
    padding: 10px;
    border-radius: 10px;
    margin: 5px 0;
}
.user-message {
    background-color: #e6f3ff;
}
.assistant-message {
    background-color: #f0f2f6;
}
</style>
""", unsafe_allow_html=True)
    # Create tabs
    tab1, tab2 = st.tabs(["Model Information", "Chat Interface"])

    with tab1:
        st.title("AMD-OLMo-1B-SFT Model Information")
        st.markdown("""
## Model Overview
AMD-OLMo-1B-SFT is an open 1B-class language model developed by AMD, built on the OLMo architecture[1][2]. Key features include:

### Architecture
- **Parameters**: 1.2B
- **Layers**: 16
- **Attention Heads**: 16
- **Hidden Size**: 2048
- **Context Length**: 2048
- **Vocabulary Size**: 50,280

### Training Details
- Pre-trained on 1.3 trillion tokens from Dolma v1.7
- Supervised fine-tuned (SFT) in two phases:
    1. Tulu V2 dataset
    2. OpenHermes-2.5, WebInstructSub, and Code-Feedback datasets

### Capabilities
- General text generation
- Question answering
- Code understanding
- Reasoning tasks
- Instruction following

### Hardware Requirements
- Optimized for AMD Instinct™ MI250 GPUs
- Training performed on 16 nodes with 4 GPUs each
""")
    with tab2:
        st.title("Chat with AMD-OLMo")

        # Load the model (cached by st.cache_resource after the first run)
        try:
            model, tokenizer = load_model()
            st.success("Model loaded successfully! You can start chatting.")
        except Exception as e:
            st.error(f"Error loading model: {str(e)}")
            return

        # Chat interface
        st.markdown("### Chat History")
        chat_container = st.container()
        with chat_container:
            for message in st.session_state.messages:
                div_class = "user-message" if message["role"] == "user" else "assistant-message"
                # Build the bubble as one flush string so Markdown does not
                # treat indented HTML as a code block
                st.markdown(
                    f'<div class="chat-message {div_class}">'
                    f'<b>{message["role"].title()}:</b> {message["content"]}'
                    f'</div>',
                    unsafe_allow_html=True,
                )
        # User input; clear_on_submit resets the text area after sending
        # (assigning to st.session_state.user_input after the widget exists
        # would raise a StreamlitAPIException)
        with st.form("chat_form", clear_on_submit=True):
            user_input = st.text_area("Your message:", height=100)
            submitted = st.form_submit_button("Send")

        if submitted and user_input.strip():
            # Generate from the prior turns first, then record both messages,
            # so the new prompt is not duplicated in the template
            with st.spinner("Thinking..."):
                response = generate_response(user_input, model, tokenizer, st.session_state.messages)
            st.session_state.messages.append({"role": "user", "content": user_input})
            st.session_state.messages.append({"role": "assistant", "content": response})
            st.rerun()

        if st.button("Clear History"):
            st.session_state.messages = []
            st.rerun()
if __name__ == "__main__":
    main()
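
# Run locally with `streamlit run app.py` (app.py is the conventional entry
# point for Streamlit apps on Hugging Face Spaces).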