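# Streamlit chat app for the amd/AMD-OLMo-1B-SFT model.
# Run locally with `streamlit run <this file>` (e.g. `streamlit run app.py` if saved under that name).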
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from datetime import datetime
# Initialize session state for chat history
if 'messages' not in st.session_state:
    st.session_state.messages = []
@st.cache_resource
def load_model():
    # Cached with @st.cache_resource so the weights are downloaded and loaded only once per session
    tokenizer = AutoTokenizer.from_pretrained("amd/AMD-OLMo-1B-SFT")
    model = AutoModelForCausalLM.from_pretrained("amd/AMD-OLMo-1B-SFT")
    if torch.cuda.is_available():
        model = model.to("cuda")
    return model, tokenizer
def generate_response(prompt, model, tokenizer, history):
    # Format the earlier turns with the model's chat template; the new prompt is appended last
    # (the AMD-OLMo-1B-SFT model card uses the EOS token string as the leading marker)
    bos = tokenizer.eos_token
    conversation = ""
    for msg in history:
        if msg["role"] == "user":
            conversation += f"<|user|>\n{msg['content']}\n"
        else:
            conversation += f"<|assistant|>\n{msg['content']}\n"
    template = bos + conversation + f"<|user|>\n{prompt}\n<|assistant|>\n"

    inputs = tokenizer([template], return_tensors='pt', return_token_type_ids=False)
    if torch.cuda.is_available():
        inputs = inputs.to("cuda")

    outputs = model.generate(
        **inputs,
        max_new_tokens=1000,
        do_sample=True,
        top_k=50,
        top_p=0.95,
        temperature=0.7
    )
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Keep only the text after the last <|assistant|> marker, i.e. the newest reply
    response = response.split("<|assistant|>\n")[-1].strip()
    return response
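# For reference, with one earlier exchange in `history` and a new `prompt`, the template
# string built above looks like this (the leading token is tokenizer.eos_token, shown as <eos>;
# the message contents are illustrative):
#
#   <eos><|user|>
#   Hi there!
#   <|assistant|>
#   Hello! How can I help you today?
#   <|user|>
#   What is OLMo?
#   <|assistant|>
#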
def main():
    st.set_page_config(page_title="AMD-OLMo Chatbot", layout="wide")

    # Custom CSS for the tabs and chat bubbles
    st.markdown("""
<style>
.stTab {
    font-size: 20px;
}
.model-info {
    background-color: #f0f2f6;
    padding: 20px;
    border-radius: 10px;
}
.chat-message {
    padding: 10px;
    border-radius: 10px;
    margin: 5px 0;
}
.user-message {
    background-color: #e6f3ff;
}
.assistant-message {
    background-color: #f0f2f6;
}
</style>
""", unsafe_allow_html=True)

    # Create tabs
    tab1, tab2 = st.tabs(["Model Information", "Chat Interface"])

    with tab1:
        st.title("AMD-OLMo-1B-SFT Model Information")
        st.markdown("""
## Model Overview
AMD-OLMo-1B-SFT is a 1B-parameter language model trained from scratch by AMD and supervised fine-tuned for instruction following. Key features include:

### Architecture
- **Base Model**: 1.2B parameters
- **Layers**: 16
- **Attention Heads**: 16
- **Hidden Size**: 2048
- **Context Length**: 2048
- **Vocabulary Size**: 50,280

### Training Details
- Pre-trained on 1.3 trillion tokens from Dolma v1.7
- Supervised fine-tuned (SFT) in two phases:
    1. Tulu V2 dataset
    2. OpenHermes-2.5, WebInstructSub, and Code-Feedback datasets

### Capabilities
- General text generation
- Question answering
- Code understanding
- Reasoning tasks
- Instruction following

### Hardware Requirements
- Optimized for AMD Instinct™ MI250 GPUs
- Training performed on 16 nodes with 4 GPUs each
""")
    with tab2:
        st.title("Chat with AMD-OLMo")

        # Load model
        try:
            model, tokenizer = load_model()
            st.success("Model loaded successfully! You can start chatting.")
        except Exception as e:
            st.error(f"Error loading model: {str(e)}")
            return

        # Chat interface
        st.markdown("### Chat History")
        chat_container = st.container()
        with chat_container:
            for message in st.session_state.messages:
                div_class = "user-message" if message["role"] == "user" else "assistant-message"
                st.markdown(f"""
<div class="chat-message {div_class}">
<b>{message["role"].title()}:</b> {message["content"]}
</div>
""", unsafe_allow_html=True)

        # User input
        with st.container():
            user_input = st.text_area("Your message:", key="user_input", height=100)
            col1, col2, col3 = st.columns([1, 1, 4])
            with col1:
                if st.button("Send"):
                    if user_input.strip():
                        # Add user message to history
                        st.session_state.messages.append({"role": "user", "content": user_input})
                        # Generate from the earlier turns only; generate_response appends
                        # the new prompt itself via the chat template
                        with st.spinner("Thinking..."):
                            response = generate_response(user_input, model, tokenizer, st.session_state.messages[:-1])
                        # Add assistant response to history
                        st.session_state.messages.append({"role": "assistant", "content": response})
                        # Rerun so the new messages show up in the history above
                        # (a widget's session_state key cannot be reassigned here,
                        # so the text area keeps its contents)
                        st.rerun()
            with col2:
                if st.button("Clear History"):
                    st.session_state.messages = []
                    st.rerun()
if __name__ == "__main__":
    main()