import streamlit as st

# Must be the first Streamlit command
st.set_page_config(
    page_title="ARIA Research Assistant",
    page_icon="🔬",
    layout="wide",
    initial_sidebar_state="auto"
)
import anthropic
import openai
import base64
import cv2
import glob
import json
import os
import pytz
import random
import re
import requests
import time
import zipfile
import plotly.graph_objects as go
import streamlit.components.v1 as components
from datetime import datetime
from audio_recorder_streamlit import audio_recorder
from bs4 import BeautifulSoup
from collections import defaultdict, deque
from dotenv import load_dotenv
from gradio_client import Client
from huggingface_hub import InferenceClient
from PIL import Image
from PyPDF2 import PdfReader
from urllib.parse import quote
from xml.etree import ElementTree as ET
from openai import OpenAI
import extra_streamlit_components as stx
import asyncio
import edge_tts
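# Approximate pip package names for the third-party imports above (an assumption;
# the Space's actual requirements.txt may pin different names or versions):
#   streamlit, anthropic, openai, opencv-python-headless, pytz, requests, plotly,
#   audio-recorder-streamlit, beautifulsoup4, python-dotenv, gradio-client,
#   huggingface-hub, Pillow, PyPDF2, extra-streamlit-components, edge-tts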
# Load environment variables
load_dotenv()

# API Setup & Clients
openai_api_key = os.getenv('OPENAI_API_KEY', st.secrets.get('OPENAI_API_KEY', ''))
anthropic_key = os.getenv('ANTHROPIC_API_KEY_3', st.secrets.get('ANTHROPIC_API_KEY', ''))
xai_key = os.getenv('xai', '')

openai.api_key = openai_api_key
claude_client = anthropic.Anthropic(api_key=anthropic_key)
openai_client = OpenAI(api_key=openai.api_key, organization=os.getenv('OPENAI_ORG_ID'))
# Session State Management
if 'chat_history' not in st.session_state:
    st.session_state['chat_history'] = []
if 'messages' not in st.session_state:
    st.session_state['messages'] = []
if 'old_val' not in st.session_state:
    st.session_state['old_val'] = None
if 'current_audio' not in st.session_state:
    st.session_state['current_audio'] = None
# Styling
st.markdown("""
    <style>
    .main { background: linear-gradient(to right, #1a1a1a, #2d2d2d); color: #fff; }
    .stButton>button {
        margin-right: 0.5rem;
        background-color: #4CAF50;
        color: white;
        padding: 0.5rem 1rem;
        border-radius: 5px;
        border: none;
    }
    </style>
""", unsafe_allow_html=True)
# Audio Functions
def clean_for_speech(text: str) -> str:
    """Clean text for speech synthesis"""
    text = text.replace("\n", " ")
    text = text.replace("</s>", " ")
    text = text.replace("#", "")
    text = re.sub(r"\(https?:\/\/[^\)]+\)", "", text)
    text = re.sub(r"\s+", " ", text).strip()
    return text
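# Illustrative example (hypothetical input): heading markers and parenthesized URLs are removed,
#   clean_for_speech("## Results\n[paper](https://arxiv.org/abs/2101.00001)")
#   -> "Results [paper]"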
def get_audio_html(audio_path):
    """Create HTML for autoplaying audio"""
    try:
        with open(audio_path, "rb") as audio_file:
            audio_bytes = audio_file.read()
        audio_b64 = base64.b64encode(audio_bytes).decode()
        return f'''
            <audio controls autoplay>
                <source src="data:audio/mpeg;base64,{audio_b64}" type="audio/mpeg">
            </audio>
            <a href="data:audio/mpeg;base64,{audio_b64}"
               download="{os.path.basename(audio_path)}">
                Download {os.path.basename(audio_path)}
            </a>
        '''
    except Exception as e:
        return f"Error loading audio: {str(e)}"
async def generate_audio(text, voice="en-US-AriaNeural"):
    """Generate audio using Edge TTS"""
    if not text.strip():
        return None
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_file = f"response_{timestamp}.mp3"
    communicate = edge_tts.Communicate(text, voice)
    await communicate.save(output_file)
    return output_file
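# Usage sketch: Streamlit scripts run synchronously, so this coroutine is driven with
# asyncio.run() (as done in perform_ai_lookup below), e.g.
#   mp3_path = asyncio.run(generate_audio("Hello from ARIA"))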
# Core Search Function
def perform_ai_lookup(query, vocal_summary=True, full_audio=False):
    """Perform search with automatic audio generation"""
    try:
        client = Client("awacke1/Arxiv-Paper-Search-And-QA-RAG-Pattern")
        refs = client.predict(
            query,
            20,
            "Semantic Search",
            "mistralai/Mixtral-8x7B-Instruct-v0.1",
            api_name="/update_with_rag_md"
        )[0]
        summary = client.predict(
            query,
            "mistralai/Mixtral-8x7B-Instruct-v0.1",
            True,
            api_name="/ask_llm"
        )

        result = f"### 🔎 Search Results\n\n{summary}\n\n### References\n\n{refs}"
        st.markdown(result)

        # Generate and play audio
        if full_audio:
            audio_file = asyncio.run(generate_audio(summary))
            if audio_file:
                st.markdown(get_audio_html(audio_file), unsafe_allow_html=True)

        return result
    except Exception as e:
        st.error(f"Error in search: {str(e)}")
        return None
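# Usage sketch (hypothetical query):
#   perform_ai_lookup("attention is all you need", vocal_summary=True, full_audio=True)
# Note: the endpoint names ("/update_with_rag_md", "/ask_llm") and argument order are defined
# by the remote Gradio Space and may change; vocal_summary is not used in this snippet.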
def main():
    st.sidebar.markdown("### Research Assistant")

    # Voice component
    mycomponent = components.declare_component("mycomponent", path="mycomponent")
    val = mycomponent(my_input_value="Hello")

    # Handle voice input
    if val:
        val_stripped = val.replace('\n', ' ')
        edited_input = st.text_area("✏️ Edit Input:", value=val_stripped, height=100)

        col1, col2 = st.columns([3, 1])
        with col1:
            model = st.selectbox("Model:", ["Arxiv", "GPT-4", "Claude"])
        with col2:
            autorun = st.checkbox("⚙ AutoRun", value=True)

        # Check for changes and autorun
        input_changed = (val != st.session_state.old_val)
        if autorun and input_changed:
            st.session_state.old_val = val
            if edited_input:
                perform_ai_lookup(edited_input, vocal_summary=True, full_audio=True)
        else:
            if st.button("🔍 Search"):
                perform_ai_lookup(edited_input, vocal_summary=True, full_audio=True)

    # Manual search tab
    st.markdown("### 🔍 Direct Search")
    query = st.text_input("Enter search query:")
    if query and st.button("Search"):
        perform_ai_lookup(query, vocal_summary=True, full_audio=True)
if __name__ == "__main__":
    main()