DexterSptizu committed
Commit 6ded5b8
1 Parent(s): 02a6b70

Create app.py

Files changed (1): app.py (+164, -0)
app.py ADDED
@@ -0,0 +1,164 @@
+ import streamlit as st
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ import torch
+ from datetime import datetime
+
+ # Initialize session state for chat history
+ if 'messages' not in st.session_state:
+     st.session_state.messages = []
+
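+ # st.cache_resource keeps a single model/tokenizer instance alive across
+ # Streamlit reruns, so the weights are downloaded and loaded only once per process.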
+ @st.cache_resource
+ def load_model():
+     tokenizer = AutoTokenizer.from_pretrained("amd/AMD-OLMo-1B-SFT")
+     model = AutoModelForCausalLM.from_pretrained("amd/AMD-OLMo-1B-SFT")
+     if torch.cuda.is_available():
+         model = model.to("cuda")
+     return model, tokenizer
+
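+ # Prompts follow the chat template from the AMD-OLMo-1B-SFT model card:
+ # the EOS token stands in for BOS, then alternating <|user|> / <|assistant|>
+ # turns, ending with an open <|assistant|> marker for the model to complete.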
+ def generate_response(prompt, model, tokenizer, history):
+     # Format conversation history with the template. AMD-OLMo's tokenizer
+     # defines no separate BOS token; following the model card, the EOS
+     # token is prepended instead.
+     bos = tokenizer.eos_token
+     conversation = ""
+     for msg in history:
+         if msg["role"] == "user":
+             conversation += f"<|user|>\n{msg['content']}\n"
+         else:
+             conversation += f"<|assistant|>\n{msg['content']}\n"
+
+     template = bos + conversation + f"<|user|>\n{prompt}\n<|assistant|>\n"
+
+     inputs = tokenizer([template], return_tensors='pt', return_token_type_ids=False)
+     if torch.cuda.is_available():
+         inputs = inputs.to("cuda")
+
+     outputs = model.generate(
+         **inputs,
+         max_new_tokens=1000,
+         do_sample=True,
+         top_k=50,
+         top_p=0.95,
+         temperature=0.7
+     )
+
+     # Decode only the newly generated tokens so the prompt and earlier
+     # turns are not echoed back in the reply.
+     new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
+     response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
+     return response
+
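+ # main() lays out a two-tab UI: a static model-information page and the chat loop.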
+ def main():
+     st.set_page_config(page_title="AMD-OLMo Chatbot", layout="wide")
+
+     # Custom CSS
+     st.markdown("""
+ <style>
+ .stTab {
+     font-size: 20px;
+ }
+ .model-info {
+     background-color: #f0f2f6;
+     padding: 20px;
+     border-radius: 10px;
+ }
+ .chat-message {
+     padding: 10px;
+     border-radius: 10px;
+     margin: 5px 0;
+ }
+ .user-message {
+     background-color: #e6f3ff;
+ }
+ .assistant-message {
+     background-color: #f0f2f6;
+ }
+ </style>
+     """, unsafe_allow_html=True)
+
+     # Create tabs
+     tab1, tab2 = st.tabs(["Model Information", "Chat Interface"])
+
78
+
79
+ with tab1:
80
+ st.title("AMD-OLMo-1B-SFT Model Information")
81
+
82
+ st.markdown("""
83
+ ## Model Overview
84
+ AMD-OLMo-1B-SFT is a state-of-the-art language model developed by AMD[1][2]. Key features include:
85
+
86
+ ### Architecture
87
+ - **Base Model**: 1.2B parameters
88
+ - **Layers**: 16
89
+ - **Attention Heads**: 16
90
+ - **Hidden Size**: 2048
91
+ - **Context Length**: 2048
92
+ - **Vocabulary Size**: 50,280
93
+
94
+ ### Training Details
95
+ - Pre-trained on 1.3 trillion tokens from Dolma v1.7
96
+ - Supervised fine-tuned (SFT) in two phases:
97
+ 1. Tulu V2 dataset
98
+ 2. OpenHermes-2.5, WebInstructSub, and Code-Feedback datasets
99
+
100
+ ### Capabilities
101
+ - General text generation
102
+ - Question answering
103
+ - Code understanding
104
+ - Reasoning tasks
105
+ - Instruction following
106
+
107
+ ### Hardware Requirements
108
+ - Optimized for AMD Instinct™ MI250 GPUs
109
+ - Training performed on 16 nodes with 4 GPUs each
110
+ """)
+     with tab2:
+         st.title("Chat with AMD-OLMo")
+
+         # Load model
+         try:
+             model, tokenizer = load_model()
+             st.success("Model loaded successfully! You can start chatting.")
+         except Exception as e:
+             st.error(f"Error loading model: {str(e)}")
+             return
+
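+         # Streamlit re-executes the script top-to-bottom on every interaction,
+         # so the full history is re-rendered from session state on each pass.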
+         # Chat interface
+         st.markdown("### Chat History")
+         chat_container = st.container()
+
+         with chat_container:
+             for message in st.session_state.messages:
+                 div_class = "user-message" if message["role"] == "user" else "assistant-message"
+                 st.markdown(f"""
+ <div class="chat-message {div_class}">
+     <b>{message["role"].title()}:</b> {message["content"]}
+ </div>
+                 """, unsafe_allow_html=True)
+
+         # User input
+         with st.container():
+             user_input = st.text_area("Your message:", key="user_input", height=100)
+             col1, col2, col3 = st.columns([1, 1, 4])
+
+             with col1:
+                 if st.button("Send"):
+                     if user_input.strip():
+                         # Add user message to history
+                         st.session_state.messages.append({"role": "user", "content": user_input})
+
+                         # Generate a response from the turns before the new prompt;
+                         # generate_response appends the prompt itself, so passing
+                         # the full history would duplicate it in the template.
+                         with st.spinner("Thinking..."):
+                             response = generate_response(user_input, model, tokenizer, st.session_state.messages[:-1])
+
+                         # Add assistant response to history
+                         st.session_state.messages.append({"role": "assistant", "content": response})
+
+                         # A widget-keyed session_state value cannot be reassigned
+                         # after the widget renders, so the text area is not cleared
+                         # here; st.rerun() replaces the removed st.experimental_rerun().
+                         st.rerun()
+
+             with col2:
+                 if st.button("Clear History"):
+                     st.session_state.messages = []
+                     st.rerun()
162
+
163
+ if __name__ == "__main__":
164
+ main()
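
To run this Space locally, install the dependencies (streamlit, transformers, torch) and launch it with `streamlit run app.py`. The prompt construction in generate_response follows the usage example on the amd/AMD-OLMo-1B-SFT model card; as a sanity check outside Streamlit, a minimal single-turn sketch of the same call path looks like the following (the prompt string and max_new_tokens value here are illustrative, not part of the committed file):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("amd/AMD-OLMo-1B-SFT")
model = AutoModelForCausalLM.from_pretrained("amd/AMD-OLMo-1B-SFT")

# Same template shape the app builds: EOS standing in for BOS, then chat markers.
prompt = "What is a large language model?"
template = tokenizer.eos_token + f"<|user|>\n{prompt}\n<|assistant|>\n"

inputs = tokenizer([template], return_tensors="pt", return_token_type_ids=False)
outputs = model.generate(**inputs, max_new_tokens=200, do_sample=True, top_k=50, top_p=0.95)

# Decode only the newly generated tokens, as generate_response does.
print(tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))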