conormack commited on
Commit
f3dd217
·
verified ·
1 Parent(s): 5d615e8

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +155 -0
app.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import pipeline
3
+ import torch
4
+ import time
5
+ from typing import List, Dict
6
+ import functools
7
+ import signal
8
+
9
class TimeoutError(Exception):
    """Raised when a guarded call exceeds its time budget.

    NOTE(review): this deliberately shadows the builtin ``TimeoutError``
    within this module — callers here catch this class, not the builtin.
    """
11
+
12
def timeout(seconds):
    """Decorator that aborts the wrapped call with TimeoutError after *seconds*.

    Implemented with SIGALRM, so it only works on Unix and only in the main
    thread.  NOTE(review): Streamlit may execute scripts off the main thread —
    confirm this decorator actually takes effect in that environment.

    Args:
        seconds: Whole seconds allowed before the call is aborted.

    Raises:
        TimeoutError: If the wrapped call runs longer than *seconds*.
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            def handler(signum, frame):
                raise TimeoutError(f"Function call timed out after {seconds} seconds")

            # Fix: remember the previous SIGALRM handler so it can be
            # restored afterwards — the original clobbered it permanently.
            previous = signal.signal(signal.SIGALRM, handler)
            signal.alarm(seconds)
            try:
                return func(*args, **kwargs)
            finally:
                signal.alarm(0)  # cancel any pending alarm
                signal.signal(signal.SIGALRM, previous)  # restore prior handler
        return wrapper
    return decorator
31
+
32
class SourceVerifier:
    """Holds reference documents and checks statements against them.

    A statement counts as "verified" when at least one stored source
    contains (case-insensitively) any word of the statement.
    """

    def __init__(self):
        # Each entry has the shape {"content": <text>, "metadata": <dict>}.
        self.sources: List[Dict] = []

    def add_source(self, text: str, metadata: Dict) -> None:
        """Register one source document along with its metadata."""
        self.sources.append({"content": text, "metadata": metadata})

    def verify_statement(self, statement: str) -> Dict:
        """Return a dict with the matching sources and a naive confidence.

        Confidence is the fraction of stored sources that matched
        (0 when no sources have been added yet).
        """
        words = [w.lower() for w in statement.split()]
        matches = [
            src for src in self.sources
            if any(w in src["content"].lower() for w in words)
        ]
        total = len(self.sources)
        return {
            "verified": bool(matches),
            "matches": matches,
            "confidence": len(matches) / total if total else 0,
        }
51
+
52
@st.cache_resource(show_spinner=False)
def load_pipeline():
    """Build and cache the text-generation pipeline (once per process).

    Returns:
        The Hugging Face pipeline, or None if loading failed (an error
        message is shown in the UI in that case).
    """
    try:
        return pipeline(
            "text-generation",
            model="sshleifer/tiny-gpt2",  # Tiny 2M parameter model
            device="cpu",  # Force CPU usage
            # Fix: `low_memory` is not a recognized from_pretrained option;
            # the correct kwarg is `low_cpu_mem_usage`.
            model_kwargs={"low_cpu_mem_usage": True},
        )
    except Exception as e:
        st.error(f"Failed to load model: {str(e)}")
        return None
64
+
65
@timeout(10)  # abort generation after 10 seconds
def generate_response(generator, prompt: str) -> str:
    """Run the text-generation pipeline on *prompt* and return its text.

    Never raises to the caller: timeouts and other failures are turned
    into human-readable error strings.
    """
    try:
        outputs = generator(
            prompt,
            max_length=50,  # keep responses short for the demo
            num_return_sequences=1,
            temperature=0.7,
            do_sample=True,
        )
        return outputs[0]['generated_text']
    except TimeoutError:
        return "Response generation timed out. Please try again."
    except Exception as e:
        return f"Error generating response: {str(e)}"
80
+
81
def init_page():
    """Configure the page chrome and seed session state on first run."""
    st.set_page_config(
        page_title="Quick Chat Demo",
        page_icon="💬",
        layout="centered"
    )
    st.title("Quick Chat Demo")

    # Seed the chat history with a greeting the first time through.
    greeting = {"role": "assistant", "content": "Hi! I'm a simple chat demo. How can I help?"}
    if "messages" not in st.session_state:
        st.session_state.messages = [greeting]

    # One verifier instance per session.
    if "verifier" not in st.session_state:
        st.session_state.verifier = SourceVerifier()
96
+
97
def handle_file_upload():
    """Render the uploader widget and register uploads as verification sources."""
    uploaded_file = st.file_uploader("Upload source document", type=["txt", "md", "json"])
    if not uploaded_file:
        return
    try:
        text = uploaded_file.read().decode()
        meta = {"filename": uploaded_file.name, "type": uploaded_file.type}
        st.session_state.verifier.add_source(text, meta)
        st.success(f"Added source: {uploaded_file.name}")
    except Exception as e:
        st.error(f"Error processing file: {str(e)}")
109
+
110
def main():
    """Wire the demo together: page setup, model, sidebar, history, chat loop."""
    init_page()

    # Model load is cached by st.cache_resource, so this is fast on reruns.
    with st.spinner("Loading (should take < 5 seconds)..."):
        generator = load_pipeline()
        if generator is None:
            st.error("Failed to initialize chat. Please refresh the page.")
            return

    # Sidebar hosts the source-document uploader.
    with st.sidebar:
        st.header("Sources")
        handle_file_upload()

    # Replay the conversation so far.
    for entry in st.session_state.messages:
        with st.chat_message(entry["role"]):
            st.write(entry["content"])

    # Handle a new user prompt, if one was submitted.
    prompt = st.chat_input("Say something")
    if prompt:
        # Record and echo the user's message.
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.write(prompt)

        # Produce the assistant reply (generation is time-boxed) and
        # check it against the uploaded sources.
        with st.chat_message("assistant"):
            with st.spinner("Responding..."):
                response = generate_response(generator, prompt)
                verification = st.session_state.verifier.verify_statement(response)

            st.write(response)
            if verification["verified"]:
                with st.expander("Sources"):
                    st.json(verification)

        st.session_state.messages.append({
            "role": "assistant",
            "content": response,
            "verification": verification
        })
153
+
154
# Entry point when the file is executed directly (e.g. `streamlit run app.py`).
if __name__ == "__main__":
    main()