"""Gradio app for semiconductor-industry Q&A: hybrid search over a
multimodal Weaviate collection with streamed GPT-4o answers."""

import gradio as gr
import weaviate
import os
from openai import AsyncOpenAI
from dotenv import load_dotenv
import asyncio
from functools import wraps
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load environment variables from a local .env file.
load_dotenv()

openai_client = AsyncOpenAI(api_key=os.getenv('OPENAI_API_KEY'))

# The Weaviate client is created in initialize_weaviate_client().
client = None

COLLECTION_NAME = os.getenv('WEAVIATE_COLLECTION_NAME')

# Shared connection state, polled by the UI status indicator.
connection_status = {"status": "Disconnected", "color": "red"}
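
# Illustrative .env layout (placeholder values; the variable names match
# the os.getenv() calls above):
#   WCS_URL=https://your-cluster.weaviate.network
#   WCS_API_KEY=<weaviate-api-key>
#   OPENAI_API_KEY=<openai-api-key>
#   WEAVIATE_COLLECTION_NAME=MultimodalDocs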


async def initialize_weaviate_client(max_retries=3, retry_delay=5):
    """Connect to Weaviate Cloud with retries, updating the shared status.

    Targets the weaviate-client v3 API (weaviate.Client); schema.get()
    doubles as a lightweight connectivity check.
    """
    global client, connection_status
    retries = 0
    while retries < max_retries:
        connection_status = {"status": "Connecting...", "color": "orange"}
        try:
            logger.info(f"Attempting to connect to Weaviate (Attempt {retries + 1}/{max_retries})")
            client = weaviate.Client(
                url=os.getenv('WCS_URL'),
                auth_client_secret=weaviate.auth.AuthApiKey(os.getenv('WCS_API_KEY')),
                additional_headers={
                    "X-OpenAI-Api-Key": os.getenv('OPENAI_API_KEY')
                }
            )
            # Run the blocking schema call in a worker thread so it does
            # not stall the event loop.
            await asyncio.to_thread(client.schema.get)
            connection_status = {"status": "Connected", "color": "green"}
            logger.info("Successfully connected to Weaviate")
            return connection_status
        except Exception as e:
            logger.error(f"Error connecting to Weaviate: {str(e)}")
            connection_status = {"status": f"Error: {str(e)}", "color": "red"}
            retries += 1
            if retries < max_retries:
                logger.info(f"Retrying in {retry_delay} seconds...")
                await asyncio.sleep(retry_delay)
    logger.error("Max retries reached. Could not connect to Weaviate.")
    return connection_status


def async_lru_cache(maxsize=128):
    """FIFO cache decorator for async functions, since functools.lru_cache
    does not support coroutines. Evicts the oldest entry once full."""
    cache = {}

    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            # str()-based keys are adequate for the plain string queries
            # used here.
            key = str(args) + str(kwargs)
            if key not in cache:
                if len(cache) >= maxsize:
                    cache.pop(next(iter(cache)))
                cache[key] = await func(*args, **kwargs)
            return cache[key]
        return wrapper
    return decorator


@async_lru_cache(maxsize=1000)
async def get_embedding(text):
    """Embed a query with OpenAI; repeat queries are served from the cache."""
    response = await openai_client.embeddings.create(
        input=text,
        model="text-embedding-3-large"
    )
    return response.data[0].embedding
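
# Illustrative cache behavior (not executed here): the same query string
# only hits the embeddings API once.
#   v1 = await get_embedding("EUV lithography")  # API call
#   v2 = await get_embedding("EUV lithography")  # cache hit; v2 is v1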


async def search_multimodal(query: str, limit: int = 30, alpha: float = 0.6):
    """Hybrid (keyword + vector) search over the multimodal collection."""
    query_vector = await get_embedding(query)
    try:
        # Build the GraphQL query, then execute the blocking .do() call in
        # a worker thread.
        query_builder = (
            client.query.get(COLLECTION_NAME, ["content_type", "url", "source_document", "page_number",
                                               "paragraph_number", "text", "image_path", "description", "table_content"])
            .with_hybrid(query=query, vector=query_vector, alpha=alpha)
            .with_limit(limit)
        )
        response = await asyncio.to_thread(query_builder.do)
        return response['data']['Get'][COLLECTION_NAME]
    except Exception as e:
        logger.error(f"An error occurred during the search: {str(e)}")
        return []
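
# Each hit is a dict keyed by the requested properties (illustrative
# values):
#   {"content_type": "text", "source_document": "annual_report.pdf",
#    "page_number": 12, "paragraph_number": 3, "text": "...", ...}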


async def generate_response_stream(query: str, context: str):
    prompt = f"""
You are an AI assistant with extensive expertise in the semiconductor industry. Your knowledge spans a wide range of companies, technologies, and products, including but not limited to: System-on-Chip (SoC) designs, Field-Programmable Gate Arrays (FPGAs), microcontrollers, integrated circuits (ICs), semiconductor manufacturing processes, and emerging technologies such as quantum computing and neuromorphic chips.

Use the following context, your knowledge, and the user's question to generate an accurate, comprehensive, and insightful answer. While formulating your response, follow these steps internally:

1. Analyze the question to identify the main topic and the specific information requested.
2. Evaluate the provided context and identify relevant information.
3. Retrieve additional relevant knowledge from your semiconductor industry expertise.
4. Reason and formulate a response by combining context and knowledge.
5. Generate a detailed response that covers all aspects of the query.
6. Review and refine your answer for coherence and accuracy.

In your output, provide only the final, polished response. Do not include your step-by-step reasoning or mention the process you followed.

IMPORTANT: Ensure your response is grounded in factual information. Do not hallucinate or invent information. If you are unsure about any aspect of the answer, or if the necessary information is not available in the provided context or your knowledge base, state this uncertainty clearly. It is better to admit a lack of information than to provide inaccurate details.

Your response should:
- Be thorough and directly address all aspects of the user's question
- Be based solely on factual information from the provided context and your reliable knowledge
- Include specific examples, data points, or case studies only when you are certain of their accuracy
- Explain technical concepts clearly, considering that the user may have varying levels of expertise
- Clearly indicate any areas where information is limited or uncertain

Context: {context}

User Question: {query}

Based on the above context and your extensive knowledge of the semiconductor industry, provide your detailed, accurate, and grounded response below. Remember: only include information you are confident is correct, and clearly state any uncertainties:
"""

    # With stream=True, the awaited call returns an async iterator of
    # incremental completion chunks.
    stream = await openai_client.chat.completions.create(
        model="gpt-4o",
        messages=[
            {"role": "system", "content": "You are an expert semiconductor industry analyst."},
            {"role": "user", "content": prompt}
        ],
        temperature=0,
        stream=True
    )
    async for chunk in stream:
        content = chunk.choices[0].delta.content
        if content is not None:
            yield content


def process_search_result(item):
    """Render one search hit as a context snippet for the prompt."""
    if item['content_type'] == 'text':
        return f"Text from {item['source_document']} (Page {item['page_number']}, Paragraph {item['paragraph_number']}): {item['text']}\n\n"
    elif item['content_type'] == 'image':
        return f"Image Description from {item['source_document']} (Page {item['page_number']}, Path: {item['image_path']}): {item['description']}\n\n"
    elif item['content_type'] == 'table':
        return f"Table Description from {item['source_document']} (Page {item['page_number']}): {item['description']}\n\n"
    return ""


async def esg_analysis_stream(user_query: str):
    """Search, assemble the context, and return (response stream, sources)."""
    search_results = await search_multimodal(user_query)

    # Plain string formatting is cheap, so build the context inline.
    context = "".join(process_search_result(item) for item in search_results)

    # Keep the top five hits as citable sources for the UI.
    sources = []
    for item in search_results[:5]:
        source = {
            "type": item.get("content_type", "Unknown"),
            "document": item.get("source_document", "N/A"),
            "page": item.get("page_number", "N/A"),
        }
        if item.get("content_type") == 'text':
            source["paragraph"] = item.get("paragraph_number", "N/A")
        elif item.get("content_type") == 'image':
            source["image_path"] = item.get("image_path", "N/A")
        sources.append(source)

    return generate_response_stream(user_query, context), sources


def format_sources(sources):
    """Format the top sources as a markdown block for the UI."""
    source_text = "## Top 5 Sources\n\n"
    for i, source in enumerate(sources, 1):
        source_text += f"### Source {i}\n"
        source_text += f"- **Type:** {source['type']}\n"
        source_text += f"- **Document:** {source['document']}\n"
        source_text += f"- **Page:** {source['page']}\n"
        if 'paragraph' in source:
            source_text += f"- **Paragraph:** {source['paragraph']}\n"
        if 'image_path' in source:
            source_text += f"- **Image Path:** {source['image_path']}\n"
        source_text += "\n"
    return source_text
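
# Rendered markdown (illustrative):
#   ## Top 5 Sources
#   ### Source 1
#   - **Type:** text
#   - **Document:** annual_report.pdf
#   - **Page:** 12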


custom_css = """
#status-box {
    position: absolute;
    top: 10px;
    right: 10px;
    background-color: white;
    padding: 5px 10px;
    border-radius: 5px;
    box-shadow: 0 2px 5px rgba(0,0,0,0.1);
    z-index: 1000;
    display: flex;
    align-items: center;
}
#status-light {
    width: 10px;
    height: 10px;
    border-radius: 50%;
    display: inline-block;
    margin-right: 5px;
}
#status-text {
    font-size: 14px;
    font-weight: bold;
}
"""


def get_connection_status():
    """Render the current connection status as the indicator's HTML."""
    status = connection_status["status"]
    color = connection_status["color"]
    return f'<div id="status-box"><div id="status-light" style="background-color: {color};"></div><span id="status-text">{status}</span></div>'


async def check_connection():
    """Probe Weaviate and report the live connection state."""
    try:
        if client:
            await asyncio.to_thread(client.schema.get)
            return {"status": "Connected", "color": "green"}
        return {"status": "Disconnected", "color": "red"}
    except Exception:
        return {"status": "Disconnected", "color": "red"}


async def update_status():
    """One polling step: re-check the connection and refresh the shared
    status so other handlers (e.g. gradio_interface) see the live state."""
    global connection_status
    connection_status = await check_connection()
    return connection_status


async def gradio_interface(user_question):
    """Click handler: stream the answer and show the formatted sources."""
    if connection_status["status"] != "Connected":
        yield "Error: Database not connected. Please wait for the connection to be established.", ""
        return

    response_generator, sources = await esg_analysis_stream(user_question)
    formatted_sources = format_sources(sources)

    # Yielding the growing answer lets Gradio stream it into the Markdown
    # output instead of waiting for the full response.
    full_response = ""
    async for response_chunk in response_generator:
        full_response += response_chunk
        yield full_response, formatted_sources


with gr.Blocks(css=custom_css) as iface:
    status_indicator = gr.HTML(get_connection_status())

    with gr.Row():
        gr.Markdown("# Semiconductor Industry Analysis")

    gr.Markdown("Ask questions about the semiconductor industry and get AI-powered answers with sources.")

    # Inputs start disabled and are enabled once the database is connected.
    user_question = gr.Textbox(lines=2, placeholder="Enter your question about the semiconductor industry...", interactive=False)
    ai_response = gr.Markdown(label="AI Response")
    sources_output = gr.Markdown(label="Sources")

    submit_btn = gr.Button("Submit", interactive=False)

    submit_btn.click(
        fn=gradio_interface,
        inputs=user_question,
        outputs=[ai_response, sources_output],
    )

    def update_status_indicator(status):
        return get_connection_status()

    def update_input_state(status):
        is_connected = status["status"] == "Connected"
        return gr.update(interactive=is_connected), gr.update(interactive=is_connected)

    status_updater = gr.State(connection_status)

    # Poll the connection once per second; the .change listeners below fire
    # only when the polled value differs from the stored state.
    iface.load(
        update_status,
        outputs=[status_updater],
        every=1,
    )

    status_updater.change(
        fn=update_status_indicator,
        inputs=[status_updater],
        outputs=[status_indicator],
    )

    status_updater.change(
        fn=update_input_state,
        inputs=[status_updater],
        outputs=[user_question, submit_btn],
    )


async def main():
    # Fail fast if any required configuration is missing.
    required_env_vars = ['WCS_URL', 'WCS_API_KEY', 'OPENAI_API_KEY', 'WEAVIATE_COLLECTION_NAME']
    for var in required_env_vars:
        if not os.getenv(var):
            logger.error(f"Environment variable {var} is not set!")
            return

    await initialize_weaviate_client()

    # Blocks.launch() is synchronous and blocks until the server exits,
    # so it must not be awaited.
    iface.launch(server_name="0.0.0.0", server_port=7860, share=True)


if __name__ == "__main__":
    asyncio.run(main())
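
# To run locally (assuming this file is saved as app.py):
#   python app.py
# Gradio serves the UI at http://localhost:7860 and, with share=True, also
# prints a temporary public link.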