conormack commited on
Commit
a6b6e00
·
1 Parent(s): 2b10037

Initial verification chat setup

Browse files
Files changed (2) hide show
  1. app.py +76 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM
3
+ import torch
4
+
5
+ def init_page():
6
+ st.set_page_config(
7
+ page_title="Verification Chat",
8
+ page_icon="🔍",
9
+ layout="centered"
10
+ )
11
+ st.title("AI Chat with Source Verification")
12
+
13
+ # Initialize session state
14
+ if "messages" not in st.session_state:
15
+ st.session_state.messages = [
16
+ {"role": "assistant", "content": "Hello! How can I help you today?"}
17
+ ]
18
+
19
+ if "model_name" not in st.session_state:
20
+ st.session_state.model_name = "facebook/opt-350m"
21
+
22
+ def load_model():
23
+ # Add caching to prevent reloading model
24
+ @st.cache_resource
25
+ def get_model():
26
+ tokenizer = AutoTokenizer.from_pretrained(st.session_state.model_name)
27
+ model = AutoModelForCausalLM.from_pretrained(st.session_state.model_name)
28
+ return tokenizer, model
29
+
30
+ return get_model()
31
+
32
+ def get_response(prompt, tokenizer, model):
33
+ inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
34
+
35
+ with torch.no_grad():
36
+ outputs = model.generate(
37
+ inputs["input_ids"],
38
+ max_length=200,
39
+ num_return_sequences=1,
40
+ temperature=0.7,
41
+ pad_token_id=tokenizer.eos_token_id
42
+ )
43
+
44
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
45
+ return response
46
+
47
+ def display_messages():
48
+ for message in st.session_state.messages:
49
+ with st.chat_message(message["role"]):
50
+ st.write(message["content"])
51
+
52
+ def main():
53
+ init_page()
54
+ tokenizer, model = load_model()
55
+
56
+ # Display chat messages
57
+ display_messages()
58
+
59
+ # Chat input
60
+ if prompt := st.chat_input("What's on your mind?"):
61
+ # Add user message
62
+ st.session_state.messages.append({"role": "user", "content": prompt})
63
+
64
+ # Display user message
65
+ with st.chat_message("user"):
66
+ st.write(prompt)
67
+
68
+ # Generate response
69
+ with st.chat_message("assistant"):
70
+ with st.spinner("Thinking..."):
71
+ response = get_response(prompt, tokenizer, model)
72
+ st.write(response)
73
+ st.session_state.messages.append({"role": "assistant", "content": response})
74
+
75
+ if __name__ == "__main__":
76
+ main()
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ streamlit
2
+ torch
3
+ transformers