import concurrent.futures
import functools
import re
import signal
import threading
import time
from typing import List, Dict

import streamlit as st
import torch
from transformers import pipeline


# Upper bound, in seconds, on a single text-generation call.
GENERATION_TIMEOUT_SECONDS = 10

# Runs of word characters; used to compare texts as sets of whole words.
_WORD_RE = re.compile(r"\w+")


class TimeoutError(Exception):
    """Raised by functions wrapped with ``timeout`` when they exceed their limit.

    NOTE: the original name is kept for compatibility; it shadows the builtin
    ``TimeoutError`` inside this module.
    """

    pass


def timeout(seconds):
    """Return a decorator that gives up on the wrapped call after *seconds*.

    In the main thread on Unix, a SIGALRM interval timer interrupts the call
    itself. Everywhere else the call runs in a worker thread and the caller
    stops waiting after *seconds*. That covers Streamlit's script thread, where
    ``signal.signal`` raises ``ValueError``, and Windows, which lacks SIGALRM.
    Python cannot kill threads, so a timed-out worker keeps running in the
    background until it finishes.

    Args:
        seconds: Time limit (int or float). 0 disables the limit.

    Raises:
        TimeoutError: (this module's class) when the limit is exceeded.
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            can_use_signal = (
                hasattr(signal, "SIGALRM")
                and hasattr(signal, "setitimer")
                and threading.current_thread() is threading.main_thread()
            )
            if can_use_signal:
                return _call_with_alarm(func, seconds, args, kwargs)
            return _call_in_worker(func, seconds, args, kwargs)

        return wrapper

    return decorator


def _call_with_alarm(func, seconds, args, kwargs):
    """Run ``func`` under a SIGALRM timer (Unix main thread only)."""

    def handler(signum, frame):
        raise TimeoutError(f"Function call timed out after {seconds} seconds")

    previous = signal.signal(signal.SIGALRM, handler)
    # setitimer (unlike signal.alarm) accepts fractional seconds.
    signal.setitimer(signal.ITIMER_REAL, seconds)
    try:
        return func(*args, **kwargs)
    finally:
        # Cancel the timer and put back whatever handler was installed before,
        # so other SIGALRM users in the process are not clobbered.
        signal.setitimer(signal.ITIMER_REAL, 0)
        if previous is not None:
            signal.signal(signal.SIGALRM, previous)


def _call_in_worker(func, seconds, args, kwargs):
    """Run ``func`` in a worker thread and stop waiting after ``seconds``."""
    executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
    future = executor.submit(func, *args, **kwargs)
    try:
        # 0 means "no limit", matching the signal-based path.
        return future.result(timeout=seconds or None)
    except concurrent.futures.TimeoutError:
        raise TimeoutError(
            f"Function call timed out after {seconds} seconds"
        ) from None
    finally:
        # Do not block on a still-running worker; it finishes in the background.
        executor.shutdown(wait=False)


class SourceVerifier:
    """Hold uploaded source texts and check generated text against them.

    Matching is a crude keyword heuristic. A source "matches" a statement when
    the two share at least one whole word, compared case-insensitively.
    """

    def __init__(self):
        # Each entry: {"content": <text>, "metadata": <dict supplied by caller>}.
        self.sources: List[Dict] = []

    def add_source(self, text: str, metadata: Dict) -> None:
        """Store *text* together with its *metadata* for later verification."""
        self.sources.append({"content": text, "metadata": metadata})

    def verify_statement(self, statement: str) -> Dict:
        """Find the sources that share at least one whole word with *statement*.

        Returns:
            A dict with three keys:
            ``verified``: True when at least one source matched.
            ``matches``: the matching source dicts.
            ``confidence``: matched sources / total sources, or 0 when there
            are no sources.
        """
        statement_words = set(_WORD_RE.findall(statement.lower()))
        # Whole-word comparison: substring tests let single letters such as
        # "a" match almost any document, and punctuation blocked real matches.
        matches = [
            source
            for source in self.sources
            if not statement_words.isdisjoint(
                _WORD_RE.findall(source["content"].lower())
            )
        ]
        return {
            "verified": len(matches) > 0,
            "matches": matches,
            "confidence": len(matches) / len(self.sources) if self.sources else 0,
        }


@st.cache_resource(show_spinner=False)
def _build_pipeline():
    """Build the text-generation pipeline; cached once per process.

    Exceptions propagate on purpose. Streamlit does not cache a raised
    exception, so a failed load is retried on the next run instead of caching
    ``None`` forever.
    """
    return pipeline(
        "text-generation",
        model="sshleifer/tiny-gpt2",  # very small model so the demo loads fast
        device="cpu",  # Force CPU usage
    )


def load_pipeline():
    """Return the cached generation pipeline, or ``None`` (after showing an error)."""
    try:
        return _build_pipeline()
    except Exception as e:
        st.error(f"Failed to load model: {str(e)}")
        return None


@timeout(GENERATION_TIMEOUT_SECONDS)
def _run_generator(generator, prompt: str) -> str:
    """Sample one short continuation of *prompt* (without echoing the prompt)."""
    result = generator(
        prompt,
        max_new_tokens=50,  # Short response, independent of prompt length
        num_return_sequences=1,
        temperature=0.7,
        do_sample=True,
        return_full_text=False,  # don't repeat the user's message back
    )
    return result[0]["generated_text"].strip()


def generate_response(generator, prompt: str) -> str:
    """Generate a reply to *prompt*, always returning a displayable string.

    A timeout or any generation error is turned into a user-facing message
    instead of being raised.
    """
    try:
        return _run_generator(generator, prompt)
    except TimeoutError:
        return "Response generation timed out. Please try again."
    except Exception as e:
        return f"Error generating response: {str(e)}"


def init_page():
    """Configure the page and seed per-session state on first run."""
    st.set_page_config(
        page_title="Quick Chat Demo",
        page_icon="💬",
        layout="centered",
    )
    st.title("Quick Chat Demo")

    if "messages" not in st.session_state:
        st.session_state.messages = [
            {"role": "assistant", "content": "Hi! I'm a simple chat demo. How can I help?"}
        ]

    if "verifier" not in st.session_state:
        st.session_state.verifier = SourceVerifier()


def handle_file_upload():
    """Let the user upload a text document and register it as a source once."""
    uploaded_file = st.file_uploader("Upload source document", type=["txt", "md", "json"])
    if not uploaded_file:
        return

    if "source_keys" not in st.session_state:
        st.session_state.source_keys = set()
    # Streamlit reruns the whole script on every interaction and the uploader
    # keeps returning the same file, so remember what has already been added.
    key = (uploaded_file.name, uploaded_file.size)
    if key in st.session_state.source_keys:
        return

    try:
        # getvalue() reads the whole buffer regardless of the stream position.
        content = uploaded_file.getvalue().decode("utf-8")
        st.session_state.verifier.add_source(
            content,
            {"filename": uploaded_file.name, "type": uploaded_file.type},
        )
    except Exception as e:
        st.error(f"Error processing file: {str(e)}")
        return

    st.session_state.source_keys.add(key)
    st.success(f"Added source: {uploaded_file.name}")


def _render_verification(verification) -> None:
    """Show matched sources in a collapsible panel when there are any."""
    if verification and verification.get("verified"):
        with st.expander("Sources"):
            st.json(verification)


def main():
    """Streamlit entry point: load the model, show history, handle new input."""
    init_page()

    # Load the model with a progress spinner
    with st.spinner("Loading (should take < 5 seconds)..."):
        generator = load_pipeline()

    if generator is None:
        st.error("Failed to initialize chat. Please refresh the page.")
        return

    # Sidebar for document upload
    with st.sidebar:
        st.header("Sources")
        handle_file_upload()

    # Replay the conversation so far, including any source panels.
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.write(message["content"])
            _render_verification(message.get("verification"))

    # Chat input
    if prompt := st.chat_input("Say something"):
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.write(prompt)

        # Generate response with timeout
        with st.chat_message("assistant"):
            with st.spinner("Responding..."):
                response = generate_response(generator, prompt)
                verification = st.session_state.verifier.verify_statement(response)

            st.write(response)
            _render_verification(verification)

        st.session_state.messages.append({
            "role": "assistant",
            "content": response,
            "verification": verification,
        })


if __name__ == "__main__":
    main()