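"""Streamlit chat app backed by a Hugging Face causal language model.

Initial verification chat setup: a minimal chat UI that loads
facebook/opt-350m via transformers and generates assistant replies.
Assuming this file is the Space's entry point app.py, run it locally with:

    streamlit run app.py

Requires streamlit, transformers, and torch.
"""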
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

def init_page():
    st.set_page_config(
        page_title="Verification Chat",
        page_icon="🔍",
        layout="centered"
    )
    st.title("AI Chat with Source Verification")

    # Initialize session state: chat history and the model to load
    if "messages" not in st.session_state:
        st.session_state.messages = [
            {"role": "assistant", "content": "Hello! How can I help you today?"}
        ]
    if "model_name" not in st.session_state:
        st.session_state.model_name = "facebook/opt-350m"

def load_model():
    # Cache the tokenizer/model so Streamlit does not reload them on
    # every rerun. The model name is passed as an argument so the cache
    # is keyed on it and invalidates if session_state.model_name changes.
    @st.cache_resource
    def get_model(model_name):
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(model_name)
        return tokenizer, model

    return get_model(st.session_state.model_name)

def get_response(prompt, tokenizer, model):
    inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
    with torch.no_grad():
        outputs = model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=200,
            num_return_sequences=1,
            do_sample=True,  # sampling must be enabled for temperature to take effect
            temperature=0.7,
            pad_token_id=tokenizer.eos_token_id
        )
    # For causal LMs, generate() returns the prompt followed by the
    # continuation, so drop the prompt tokens before decoding.
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True)
    return response

def display_messages():
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.write(message["content"])

def main():
    init_page()
    tokenizer, model = load_model()

    # Display chat messages
    display_messages()

    # Chat input
    if prompt := st.chat_input("What's on your mind?"):
        # Add user message
        st.session_state.messages.append({"role": "user", "content": prompt})

        # Display user message
        with st.chat_message("user"):
            st.write(prompt)

        # Generate response
        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                response = get_response(prompt, tokenizer, model)
                st.write(response)
                st.session_state.messages.append({"role": "assistant", "content": response})


if __name__ == "__main__":
    main()