import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from datetime import datetime

MODEL_ID = "amd/AMD-OLMo-1B-SFT"
# Context window of AMD-OLMo-1B (the "Context Length" row of MODEL_INFO_MD).
MAX_CONTEXT_TOKENS = 2048
# Upper bound on tokens generated per reply.
MAX_NEW_TOKENS = 1000
TIMESTAMP_FORMAT = "%Y-%m-%d %H:%M:%S"
# Session-state flag: "empty the input box on the next run". A widget's value
# may only be assigned through st.session_state BEFORE the widget is created
# in a run, so the handlers set this flag instead of clearing it directly.
_CLEAR_INPUT_FLAG = "_clear_user_input"

MODEL_INFO_MD = """
### Model Overview

AMD-OLMo-1B-SFT is a state-of-the-art language model developed by AMD. This model represents a significant advancement in AMD's AI capabilities.

### Architecture Specifications

| Component | Specification |
|-----------|---------------|
| Parameters | 1.2B |
| Layers | 16 |
| Attention Heads | 16 |
| Hidden Size | 2048 |
| Context Length | 2048 |
| Vocabulary Size | 50,280 |

### Training Details

- Pre-trained on 1.3 trillion tokens from Dolma v1.7
- Two-phase supervised fine-tuning (SFT):
    1. Tulu V2 dataset
    2. OpenHermes-2.5, WebInstructSub, and Code-Feedback datasets

### Key Capabilities

- Natural language understanding and generation
- Context-aware responses
- Code understanding and generation
- Complex reasoning tasks
- Instruction following
- Multi-turn conversations

### Hardware Optimization

- Optimized for AMD Instinct™ MI250 GPUs
- Distributed training across 16 nodes with 4 GPUs each
- Efficient inference on consumer hardware
"""

# Initialize session state variables
if "messages" not in st.session_state:
    st.session_state.messages = []
if "user_input_widget" not in st.session_state:
    st.session_state.user_input_widget = ""


def _now():
    """Return the current local time formatted with TIMESTAMP_FORMAT."""
    return datetime.now().strftime(TIMESTAMP_FORMAT)


def _rerun():
    """Trigger a script rerun on both older and newer Streamlit versions."""
    # st.experimental_rerun was deprecated in favour of st.rerun and has since
    # been removed, so prefer the new name when it exists.
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    rerun()


@st.cache_resource
def load_model():
    """Load the AMD-OLMo tokenizer and model once per server process.

    Returns:
        A ``(model, tokenizer)`` tuple. The model is moved to CUDA when a GPU
        is available.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
    if torch.cuda.is_available():
        model = model.to("cuda")
    return model, tokenizer


def _format_prompt(prompt, history, bos):
    """Render *history* plus the new *prompt* in the <|user|>/<|assistant|> template.

    Messages whose role is not "user" are rendered as assistant turns.
    """
    turns = [
        f"<|user|>\n{msg['content']}\n"
        if msg["role"] == "user"
        else f"<|assistant|>\n{msg['content']}\n"
        for msg in history
    ]
    return bos + "".join(turns) + f"<|user|>\n{prompt}\n<|assistant|>\n"


def generate_response(prompt, model, tokenizer, history):
    """Generate the assistant's reply to *prompt* given the earlier *history*.

    Args:
        prompt: The new user message. It must NOT already be in *history*,
            because it is appended to the template here.
        model: Causal LM returned by ``load_model()``.
        tokenizer: Tokenizer returned by ``load_model()``.
        history: Earlier turns, as dicts with "role" and "content" keys.

    Returns:
        The decoded text of the newly generated tokens, with surrounding
        whitespace stripped.
    """
    # The EOS token is used as the leading marker of the template (this
    # mirrors the original code's choice of tokenizer.eos_token).
    bos = tokenizer.eos_token
    max_prompt_tokens = MAX_CONTEXT_TOKENS - MAX_NEW_TOKENS

    # Drop the oldest turns until the prompt leaves room for the reply inside
    # the context window. The new prompt itself is always kept, even if it
    # alone is too long.
    start = 0
    while True:
        template = _format_prompt(prompt, history[start:], bos)
        inputs = tokenizer(
            [template], return_tensors="pt", return_token_type_ids=False
        )
        prompt_len = inputs["input_ids"].shape[1]
        if start >= len(history) or prompt_len <= max_prompt_tokens:
            break
        start += 1

    inputs = inputs.to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=True,
        top_k=50,
        top_p=0.95,
        temperature=0.7,
    )
    # Decode only the tokens generated after the prompt. Splitting the full
    # decode on "<|assistant|>" is fragile: if the markers are special tokens,
    # skip_special_tokens removes them and the split returns the whole prompt.
    new_tokens = outputs[0][prompt_len:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()


def _render_history(messages):
    """Render the stored chat messages, oldest first."""
    for message in messages:
        timestamp = message.get("timestamp", _now())
        # No unsafe_allow_html: user input and model output are untrusted and
        # must not be injected into the page as raw HTML.
        st.markdown(f"**{message['role'].title()}:** {message['content']}")
        st.caption(timestamp)


def _handle_send(user_input, model, tokenizer):
    """Generate a reply to *user_input*, store both turns and rerun the app."""
    if not user_input.strip():
        return
    sent_at = _now()
    # Pass the history as it was BEFORE this message; generate_response adds
    # the prompt itself, so appending first would duplicate the user turn.
    history = list(st.session_state.messages)
    with st.spinner("Generating response..."):
        response = generate_response(user_input, model, tokenizer, history)
    st.session_state.messages.append(
        {"role": "user", "content": user_input, "timestamp": sent_at}
    )
    st.session_state.messages.append(
        {"role": "assistant", "content": response, "timestamp": _now()}
    )
    st.session_state[_CLEAR_INPUT_FLAG] = True
    _rerun()


def main():
    """Configure the page and render the model-info and chat tabs."""
    st.set_page_config(
        page_title="AMD-OLMo Chatbot",
        layout="wide",
        initial_sidebar_state="expanded",
    )

    tab1, tab2 = st.tabs(["Model Information", "Chat Interface"])

    with tab1:
        st.title("AMD-OLMo-1B-SFT Model Information")
        with st.container():
            st.markdown(MODEL_INFO_MD)

    with tab2:
        st.title("Chat with AMD-OLMo")

        # Load model
        try:
            model, tokenizer = load_model()
        except Exception as e:  # UI boundary: report and stop rendering the tab.
            st.error(f"Error loading model: {str(e)}")
            return
        else:
            st.success("Model loaded successfully! You can start chatting.")

        # Chat interface
        st.markdown("### Chat History")
        with st.container():
            _render_history(st.session_state.messages)

        # Clearing the input box is only allowed before the widget is created.
        if st.session_state.pop(_CLEAR_INPUT_FLAG, False):
            st.session_state.user_input_widget = ""

        # User input section
        with st.container():
            user_input = st.text_area(
                "Your message:",
                key="user_input_widget",
                height=100,
                placeholder="Type your message here...",
            )

            col1, col2, _ = st.columns([1, 1, 4])
            with col1:
                if st.button("Send", use_container_width=True):
                    _handle_send(user_input, model, tokenizer)
            with col2:
                if st.button("Clear History", use_container_width=True):
                    st.session_state.messages = []
                    st.session_state[_CLEAR_INPUT_FLAG] = True
                    _rerun()


if __name__ == "__main__":
    main()