import gradio as gr | |
from gradio.helpers import Progress | |
import asyncio | |
import subprocess | |
import yaml | |
import os | |
import networkx as nx | |
import plotly.graph_objects as go | |
import numpy as np | |
import plotly.io as pio | |
import lancedb | |
import random | |
import io | |
import shutil | |
import logging | |
import queue | |
import threading | |
import time | |
from collections import deque | |
import re | |
import glob | |
from datetime import datetime | |
import json | |
import requests | |
import aiohttp | |
from openai import OpenAI | |
from openai import AsyncOpenAI | |
import pyarrow.parquet as pq | |
import pandas as pd | |
import sys | |
import colorsys | |
from dotenv import load_dotenv, set_key | |
import argparse | |
import socket | |
import tiktoken | |
from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey | |
from graphrag.query.indexer_adapters import ( | |
read_indexer_covariates, | |
read_indexer_entities, | |
read_indexer_relationships, | |
read_indexer_reports, | |
read_indexer_text_units, | |
) | |
from graphrag.llm.openai import create_openai_chat_llm | |
from graphrag.llm.openai.factories import create_openai_embedding_llm | |
from graphrag.query.input.loaders.dfs import store_entity_semantic_embeddings | |
from graphrag.query.llm.oai.chat_openai import ChatOpenAI | |
from graphrag.llm.openai.openai_configuration import OpenAIConfiguration | |
from graphrag.llm.openai.openai_embeddings_llm import OpenAIEmbeddingsLLM | |
from graphrag.query.llm.oai.typing import OpenaiApiType | |
from graphrag.query.structured_search.local_search.mixed_context import LocalSearchMixedContext | |
from graphrag.query.structured_search.local_search.search import LocalSearch | |
from graphrag.query.structured_search.global_search.community_context import GlobalCommunityContext | |
from graphrag.query.structured_search.global_search.search import GlobalSearch | |
from graphrag.vector_stores.lancedb import LanceDBVectorStore | |
import textwrap | |
# Suppress warnings | |
import warnings | |
warnings.filterwarnings("ignore", category=UserWarning, module="gradio_client.documentation") | |
load_dotenv('indexing/.env') | |
# Set default values for API-related environment variables | |
os.environ.setdefault("LLM_API_BASE", os.getenv("LLM_API_BASE")) | |
os.environ.setdefault("LLM_API_KEY", os.getenv("LLM_API_KEY")) | |
os.environ.setdefault("LLM_MODEL", os.getenv("LLM_MODEL")) | |
os.environ.setdefault("EMBEDDINGS_API_BASE", os.getenv("EMBEDDINGS_API_BASE")) | |
os.environ.setdefault("EMBEDDINGS_API_KEY", os.getenv("EMBEDDINGS_API_KEY")) | |
os.environ.setdefault("EMBEDDINGS_MODEL", os.getenv("EMBEDDINGS_MODEL")) | |
# Add the project root to the Python path | |
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) | |
sys.path.insert(0, project_root) | |
# Set up logging | |
log_queue = queue.Queue() | |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') | |
llm = None | |
text_embedder = None | |
class QueueHandler(logging.Handler): | |
def __init__(self, log_queue): | |
super().__init__() | |
self.log_queue = log_queue | |
def emit(self, record): | |
self.log_queue.put(self.format(record)) | |
queue_handler = QueueHandler(log_queue) | |
logging.getLogger().addHandler(queue_handler) | |
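# Build the chat LLM and embeddings client from the environment variables loaded above.
# Both are stored in module-level globals so the query/chat handlers can reuse them.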
def initialize_models(): | |
global llm, text_embedder | |
llm_api_base = os.getenv("LLM_API_BASE") | |
llm_api_key = os.getenv("LLM_API_KEY") | |
embeddings_api_base = os.getenv("EMBEDDINGS_API_BASE") | |
embeddings_api_key = os.getenv("EMBEDDINGS_API_KEY") | |
llm_service_type = os.getenv("LLM_SERVICE_TYPE", "openai_chat").lower() # Provide a default and lower it | |
embeddings_service_type = os.getenv("EMBEDDINGS_SERVICE_TYPE", "openai").lower() # Provide a default and lower it | |
llm_model = os.getenv("LLM_MODEL") | |
embeddings_model = os.getenv("EMBEDDINGS_MODEL") | |
logging.info("Fetching models...") | |
models = fetch_models(llm_api_base, llm_api_key, llm_service_type) | |
# Use the same models list for both LLM and embeddings | |
llm_models = models | |
embeddings_models = models | |
# Initialize LLM | |
if llm_service_type == "openai_chat": | |
llm = ChatOpenAI( | |
api_key=llm_api_key, | |
api_base=f"{llm_api_base}/v1", | |
model=llm_model, | |
api_type=OpenaiApiType.OpenAI, | |
max_retries=20, | |
) | |
# Initialize OpenAI client for embeddings | |
openai_client = OpenAI( | |
api_key=embeddings_api_key or "dummy_key", | |
base_url=f"{embeddings_api_base}/v1" | |
) | |
# Initialize text embedder using OpenAIEmbeddingsLLM | |
text_embedder = OpenAIEmbeddingsLLM( | |
client=openai_client, | |
configuration={ | |
"model": embeddings_model, | |
"api_type": "open_ai", | |
"api_base": embeddings_api_base, | |
"api_key": embeddings_api_key or None, | |
"provider": embeddings_service_type | |
} | |
) | |
return llm_models, embeddings_models, llm_service_type, embeddings_service_type, llm_api_base, embeddings_api_base, text_embedder | |
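# Locate the most recent timestamped run folder under ./indexing/output
# (folders are named like 20240101-123000) and require it to contain an artifacts/ subfolder.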
def find_latest_output_folder(): | |
root_dir = "./indexing/output" | |
folders = [f for f in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, f))] | |
if not folders: | |
raise ValueError("No output folders found") | |
# Sort folders by creation time, most recent first | |
sorted_folders = sorted(folders, key=lambda x: os.path.getctime(os.path.join(root_dir, x)), reverse=True) | |
latest_folder = None | |
timestamp = None | |
for folder in sorted_folders: | |
try: | |
# Try to parse the folder name as a timestamp | |
timestamp = datetime.strptime(folder, "%Y%m%d-%H%M%S") | |
latest_folder = folder | |
break | |
except ValueError: | |
# If the folder name is not a valid timestamp, skip it | |
continue | |
if latest_folder is None: | |
raise ValueError("No valid timestamp folders found") | |
latest_path = os.path.join(root_dir, latest_folder) | |
artifacts_path = os.path.join(latest_path, "artifacts") | |
if not os.path.exists(artifacts_path): | |
raise ValueError(f"Artifacts folder not found in {latest_path}") | |
return latest_path, latest_folder | |
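# Load the GraphRAG parquet artifacts (nodes, edges, text units, reports, covariates)
# from the latest run into module-level DataFrames; missing tables become empty DataFrames.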
def initialize_data(): | |
global entity_df, relationship_df, text_unit_df, report_df, covariate_df | |
tables = { | |
"entity_df": "create_final_nodes", | |
"relationship_df": "create_final_edges", | |
"text_unit_df": "create_final_text_units", | |
"report_df": "create_final_reports", | |
"covariate_df": "create_final_covariates" | |
} | |
timestamp = None # Initialize timestamp to None | |
try: | |
latest_output_folder, timestamp = find_latest_output_folder() | |
artifacts_folder = os.path.join(latest_output_folder, "artifacts") | |
for df_name, file_prefix in tables.items(): | |
file_pattern = os.path.join(artifacts_folder, f"{file_prefix}*.parquet") | |
matching_files = glob.glob(file_pattern) | |
if matching_files: | |
latest_file = max(matching_files, key=os.path.getctime) | |
df = pd.read_parquet(latest_file) | |
globals()[df_name] = df | |
logging.info(f"Successfully loaded {df_name} from {latest_file}") | |
else: | |
logging.warning(f"No matching file found for {df_name} in {artifacts_folder}. Initializing as an empty DataFrame.") | |
globals()[df_name] = pd.DataFrame() | |
except Exception as e: | |
logging.error(f"Error initializing data: {str(e)}") | |
for df_name in tables.keys(): | |
globals()[df_name] = pd.DataFrame() | |
return timestamp | |
# Call initialize_data and store the timestamp | |
current_timestamp = initialize_data() | |
def find_available_port(start_port, max_attempts=100): | |
for port in range(start_port, start_port + max_attempts): | |
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: | |
try: | |
s.bind(('', port)) | |
return port | |
except OSError: | |
continue | |
raise IOError("No free ports found") | |
def start_api_server(port): | |
subprocess.Popen([sys.executable, "api_server.py", "--port", str(port)]) | |
def wait_for_api_server(port): | |
max_retries = 30 | |
for _ in range(max_retries): | |
try: | |
response = requests.get(f"http://localhost:{port}") | |
if response.status_code == 200: | |
print(f"API server is up and running on port {port}") | |
return | |
else: | |
print(f"Unexpected response from API server: {response.status_code}") | |
except requests.ConnectionError: | |
time.sleep(1) | |
print("Failed to connect to API server") | |
def load_settings(): | |
try: | |
with open("indexing/settings.yaml", "r") as f: | |
return yaml.safe_load(f) or {} | |
except FileNotFoundError: | |
return {} | |
def update_setting(key, value): | |
settings = load_settings() | |
try: | |
settings[key] = json.loads(value) | |
except json.JSONDecodeError: | |
settings[key] = value | |
try: | |
with open("indexing/settings.yaml", "w") as f: | |
yaml.dump(settings, f, default_flow_style=False) | |
return f"Setting '{key}' updated successfully" | |
except Exception as e: | |
return f"Error updating setting '{key}': {str(e)}" | |
def create_setting_component(key, value): | |
with gr.Accordion(key, open=False): | |
if isinstance(value, (dict, list)): | |
value_str = json.dumps(value, indent=2) | |
lines = value_str.count('\n') + 1 | |
else: | |
value_str = str(value) | |
lines = 1 | |
text_area = gr.TextArea(value=value_str, label="Value", lines=lines, max_lines=20) | |
update_btn = gr.Button("Update", variant="primary") | |
status = gr.Textbox(label="Status", visible=False) | |
update_btn.click( | |
fn=update_setting, | |
inputs=[gr.Textbox(value=key, visible=False), text_area], | |
outputs=[status] | |
).then( | |
fn=lambda: gr.update(visible=True), | |
outputs=[status] | |
) | |
def get_openai_client():
    # The OpenAI client takes no model argument; the model is passed on each request instead
    return OpenAI(
        base_url=os.getenv("LLM_API_BASE"),
        api_key=os.getenv("LLM_API_KEY")
    )
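# One-shot async chat completion against an OpenAI-compatible endpoint.
# Errors are returned as text so callers can surface them in the UI instead of raising.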
async def chat_with_openai(messages, model, temperature, max_tokens, api_base): | |
client = AsyncOpenAI( | |
base_url=api_base, | |
api_key=os.getenv("LLM_API_KEY") | |
) | |
try: | |
response = await client.chat.completions.create( | |
model=model, | |
messages=messages, | |
temperature=temperature, | |
max_tokens=max_tokens | |
) | |
return response.choices[0].message.content | |
except Exception as e:
    logging.error(f"Error in chat_with_openai: {str(e)}")
    return f"An error occurred: {str(e)}"
def chat_with_llm(query, history, system_message, temperature, max_tokens, model, api_base): | |
try: | |
messages = [{"role": "system", "content": system_message}] | |
for item in history: | |
if isinstance(item, tuple) and len(item) == 2: | |
human, ai = item | |
messages.append({"role": "user", "content": human}) | |
messages.append({"role": "assistant", "content": ai}) | |
messages.append({"role": "user", "content": query}) | |
logging.info(f"Sending chat request to {api_base} with model {model}") | |
client = OpenAI(base_url=api_base, api_key=os.getenv("LLM_API_KEY", "dummy-key")) | |
response = client.chat.completions.create( | |
model=model, | |
messages=messages, | |
temperature=temperature, | |
max_tokens=max_tokens | |
) | |
return response.choices[0].message.content | |
except Exception as e: | |
logging.error(f"Error in chat_with_llm: {str(e)}") | |
logging.error(f"Attempted with model: {model}, api_base: {api_base}") | |
raise RuntimeError(f"Chat request failed: {str(e)}") | |
def run_graphrag_query(cli_args): | |
try: | |
command = ' '.join(cli_args) | |
logging.info(f"Executing command: {command}") | |
result = subprocess.run(cli_args, capture_output=True, text=True, check=True) | |
return result.stdout.strip() | |
except subprocess.CalledProcessError as e: | |
logging.error(f"Error running GraphRAG query: {e}") | |
logging.error(f"Command output (stdout): {e.stdout}") | |
logging.error(f"Command output (stderr): {e.stderr}") | |
raise RuntimeError(f"GraphRAG query failed: {e.stderr}") | |
def parse_query_response(response: str): | |
try: | |
# Split the response into metadata and content | |
parts = response.split("\n\n", 1) | |
if len(parts) < 2: | |
return response # Return original response if it doesn't contain metadata | |
metadata_str, content = parts | |
metadata = json.loads(metadata_str) | |
# Extract relevant information from metadata | |
query_type = metadata.get("query_type", "Unknown") | |
execution_time = metadata.get("execution_time", "N/A") | |
tokens_used = metadata.get("tokens_used", "N/A") | |
# Remove unwanted lines from the content | |
content_lines = content.split('\n') | |
filtered_content = '\n'.join([line for line in content_lines if not line.startswith("INFO:") and not line.startswith("creating llm client")]) | |
# Format the parsed response | |
parsed_response = f""" | |
Query Type: {query_type} | |
Execution Time: {execution_time} seconds | |
Tokens Used: {tokens_used} | |
{filtered_content.strip()} | |
""" | |
return parsed_response | |
except Exception as e: | |
print(f"Error parsing query response: {str(e)}") | |
return response | |
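# Chat dispatcher: "global"/"local" queries shell out to the GraphRAG CLI,
# while "direct" talks straight to the configured LLM endpoint.
# Results (or errors) are appended to the chat history shown in the UI.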
def send_message(query_type, query, history, system_message, temperature, max_tokens, preset, community_level, response_type, custom_cli_args, selected_folder): | |
try: | |
if query_type in ["global", "local"]: | |
cli_args = construct_cli_args(query_type, preset, community_level, response_type, custom_cli_args, query, selected_folder) | |
logging.info(f"Executing {query_type} search with command: {' '.join(cli_args)}") | |
result = run_graphrag_query(cli_args) | |
parsed_result = parse_query_response(result) | |
logging.info(f"Parsed query result: {parsed_result}") | |
else: # Direct chat | |
llm_model = os.getenv("LLM_MODEL") | |
api_base = os.getenv("LLM_API_BASE") | |
logging.info(f"Executing direct chat with model: {llm_model}") | |
try: | |
result = chat_with_llm(query, history, system_message, temperature, max_tokens, llm_model, api_base) | |
parsed_result = result # No parsing needed for direct chat | |
logging.info(f"Direct chat result: {parsed_result[:100]}...") # Log first 100 chars of result | |
except Exception as chat_error: | |
logging.error(f"Error in chat_with_llm: {str(chat_error)}") | |
raise RuntimeError(f"Direct chat failed: {str(chat_error)}") | |
history.append((query, parsed_result)) | |
except Exception as e: | |
error_message = f"An error occurred: {str(e)}" | |
logging.error(error_message) | |
logging.exception("Exception details:") | |
history.append((query, error_message)) | |
return history, gr.update(value=""), update_logs() | |
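# Translate the selected preset, community level and response type into CLI flags
# for python -m graphrag.query, using the chosen index folder's artifacts as --data.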
def construct_cli_args(query_type, preset, community_level, response_type, custom_cli_args, query, selected_folder): | |
if not selected_folder: | |
raise ValueError("No folder selected. Please select an output folder before querying.") | |
artifacts_folder = os.path.join("./indexing/output", selected_folder, "artifacts") | |
if not os.path.exists(artifacts_folder): | |
raise ValueError(f"Artifacts folder not found in {artifacts_folder}") | |
base_args = [ | |
"python", "-m", "graphrag.query", | |
"--data", artifacts_folder, | |
"--method", query_type, | |
] | |
# Apply preset configurations | |
if preset.startswith("Default"): | |
base_args.extend(["--community_level", "2", "--response_type", "Multiple Paragraphs"]) | |
elif preset.startswith("Detailed"): | |
base_args.extend(["--community_level", "4", "--response_type", "Multi-Page Report"]) | |
elif preset.startswith("Quick"): | |
base_args.extend(["--community_level", "1", "--response_type", "Single Paragraph"]) | |
elif preset.startswith("Bullet"): | |
base_args.extend(["--community_level", "2", "--response_type", "List of 3-7 Points"]) | |
elif preset.startswith("Comprehensive"): | |
base_args.extend(["--community_level", "5", "--response_type", "Multi-Page Report"]) | |
elif preset.startswith("High-Level"): | |
base_args.extend(["--community_level", "1", "--response_type", "Single Page"]) | |
elif preset.startswith("Focused"): | |
base_args.extend(["--community_level", "3", "--response_type", "Multiple Paragraphs"]) | |
elif preset == "Custom Query": | |
base_args.extend([ | |
"--community_level", str(community_level), | |
"--response_type", f'"{response_type}"', | |
]) | |
if custom_cli_args: | |
base_args.extend(custom_cli_args.split()) | |
# Add the query at the end | |
base_args.append(query) | |
return base_args | |
def upload_file(file): | |
if file is not None: | |
input_dir = os.path.join("indexing", "input") | |
os.makedirs(input_dir, exist_ok=True) | |
# Get the original filename from the uploaded file | |
original_filename = file.name | |
# Create the destination path | |
destination_path = os.path.join(input_dir, os.path.basename(original_filename)) | |
# Move the uploaded file to the destination path | |
shutil.move(file.name, destination_path) | |
logging.info(f"File uploaded and moved to: {destination_path}") | |
status = f"File uploaded: {os.path.basename(original_filename)}" | |
else: | |
status = "No file uploaded" | |
# Get the updated file list | |
updated_file_list = [f["path"] for f in list_input_files()] | |
return status, gr.update(choices=updated_file_list), update_logs() | |
def list_input_files(): | |
input_dir = os.path.join("indexing", "input") | |
files = [] | |
if os.path.exists(input_dir): | |
files = os.listdir(input_dir) | |
return [{"name": f, "path": os.path.join(input_dir, f)} for f in files] | |
def delete_file(file_path): | |
try: | |
os.remove(file_path) | |
logging.info(f"File deleted: {file_path}") | |
status = f"File deleted: {os.path.basename(file_path)}" | |
except Exception as e: | |
logging.error(f"Error deleting file: {str(e)}") | |
status = f"Error deleting file: {str(e)}" | |
# Get the updated file list | |
updated_file_list = [f["path"] for f in list_input_files()] | |
return status, gr.update(choices=updated_file_list), update_logs() | |
def read_file_content(file_path): | |
try: | |
if file_path.endswith('.parquet'): | |
df = pd.read_parquet(file_path) | |
# Get basic information about the DataFrame | |
info = f"Parquet File: {os.path.basename(file_path)}\n" | |
info += f"Rows: {len(df)}, Columns: {len(df.columns)}\n\n" | |
info += "Column Names:\n" + "\n".join(df.columns) + "\n\n" | |
# Display first few rows | |
info += "First 5 rows:\n" | |
info += df.head().to_string() + "\n\n" | |
# Display basic statistics | |
info += "Basic Statistics:\n" | |
info += df.describe().to_string() | |
return info | |
else: | |
with open(file_path, 'r', encoding='utf-8', errors='replace') as file: | |
content = file.read() | |
return content | |
except Exception as e: | |
logging.error(f"Error reading file: {str(e)}") | |
return f"Error reading file: {str(e)}" | |
def save_file_content(file_path, content): | |
try: | |
with open(file_path, 'w') as file: | |
file.write(content) | |
logging.info(f"File saved: {file_path}") | |
status = f"File saved: {os.path.basename(file_path)}" | |
except Exception as e: | |
logging.error(f"Error saving file: {str(e)}") | |
status = f"Error saving file: {str(e)}" | |
return status, update_logs() | |
def manage_data(): | |
db = lancedb.connect("./indexing/lancedb") | |
tables = db.table_names() | |
table_info = "" | |
if tables: | |
table = db[tables[0]] | |
table_info = f"Table: {tables[0]}\nSchema: {table.schema}" | |
input_files = list_input_files() | |
return { | |
"database_info": f"Tables: {', '.join(tables)}\n\n{table_info}", | |
"input_files": input_files | |
} | |
def find_latest_graph_file(root_dir): | |
pattern = os.path.join(root_dir, "output", "*", "artifacts", "*.graphml") | |
graph_files = glob.glob(pattern) | |
if not graph_files: | |
# If no files found, try excluding .DS_Store | |
output_dir = os.path.join(root_dir, "output") | |
run_dirs = [d for d in os.listdir(output_dir) if os.path.isdir(os.path.join(output_dir, d)) and d != ".DS_Store"] | |
if run_dirs: | |
latest_run = max(run_dirs) | |
pattern = os.path.join(root_dir, "output", latest_run, "artifacts", "*.graphml") | |
graph_files = glob.glob(pattern) | |
if not graph_files: | |
return None | |
# Sort files by modification time, most recent first | |
latest_file = max(graph_files, key=os.path.getmtime) | |
return latest_file | |
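# Render the selected GraphML file with Plotly: spring (2D/3D) or circular layout,
# node colour by degree or random value, optional node labels.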
def update_visualization(folder_name, file_name, layout_type, node_size, edge_width, node_color_attribute, color_scheme, show_labels, label_size): | |
root_dir = "./indexing" | |
if not folder_name or not file_name: | |
return None, "Please select a folder and a GraphML file." | |
file_name = file_name.split("] ")[1] if "]" in file_name else file_name # Remove file type prefix | |
graph_path = os.path.join(root_dir, "output", folder_name, "artifacts", file_name) | |
if not graph_path.endswith('.graphml'): | |
return None, "Please select a GraphML file for visualization." | |
try: | |
# Load the GraphML file | |
graph = nx.read_graphml(graph_path) | |
# Create layout based on user selection | |
if layout_type == "3D Spring": | |
pos = nx.spring_layout(graph, dim=3, seed=42, k=0.5) | |
elif layout_type == "2D Spring": | |
pos = nx.spring_layout(graph, dim=2, seed=42, k=0.5) | |
else: # Circular | |
pos = nx.circular_layout(graph) | |
# Extract node positions | |
if layout_type == "3D Spring": | |
x_nodes = [pos[node][0] for node in graph.nodes()] | |
y_nodes = [pos[node][1] for node in graph.nodes()] | |
z_nodes = [pos[node][2] for node in graph.nodes()] | |
else: | |
x_nodes = [pos[node][0] for node in graph.nodes()] | |
y_nodes = [pos[node][1] for node in graph.nodes()] | |
z_nodes = [0] * len(graph.nodes()) # Set all z-coordinates to 0 for 2D layouts | |
# Extract edge positions | |
x_edges, y_edges, z_edges = [], [], [] | |
for edge in graph.edges(): | |
x_edges.extend([pos[edge[0]][0], pos[edge[1]][0], None]) | |
y_edges.extend([pos[edge[0]][1], pos[edge[1]][1], None]) | |
if layout_type == "3D Spring": | |
z_edges.extend([pos[edge[0]][2], pos[edge[1]][2], None]) | |
else: | |
z_edges.extend([0, 0, None]) | |
# Generate node colors based on user selection | |
if node_color_attribute == "Degree": | |
node_colors = [graph.degree(node) for node in graph.nodes()] | |
else: # Random | |
node_colors = [random.random() for _ in graph.nodes()] | |
node_colors = np.array(node_colors) | |
node_colors = (node_colors - node_colors.min()) / (node_colors.max() - node_colors.min()) | |
# Create the trace for edges | |
edge_trace = go.Scatter3d( | |
x=x_edges, y=y_edges, z=z_edges, | |
mode='lines', | |
line=dict(color='lightgray', width=edge_width), | |
hoverinfo='none' | |
) | |
# Create the trace for nodes | |
node_trace = go.Scatter3d( | |
x=x_nodes, y=y_nodes, z=z_nodes, | |
mode='markers+text' if show_labels else 'markers', | |
marker=dict( | |
size=node_size, | |
color=node_colors, | |
colorscale=color_scheme, | |
colorbar=dict( | |
title='Node Degree' if node_color_attribute == "Degree" else "Random Value", | |
thickness=10, | |
x=1.1, | |
tickvals=[0, 1], | |
ticktext=['Low', 'High'] | |
), | |
line=dict(width=1) | |
), | |
text=[node for node in graph.nodes()], | |
textposition="top center", | |
textfont=dict(size=label_size, color='black'), | |
hoverinfo='text' | |
) | |
# Create the plot | |
fig = go.Figure(data=[edge_trace, node_trace]) | |
# Update layout for better visualization | |
fig.update_layout( | |
title=f'{layout_type} Graph Visualization: {os.path.basename(graph_path)}', | |
showlegend=False, | |
scene=dict( | |
xaxis=dict(showbackground=False, showticklabels=False, title=''), | |
yaxis=dict(showbackground=False, showticklabels=False, title=''), | |
zaxis=dict(showbackground=False, showticklabels=False, title='') | |
), | |
margin=dict(l=0, r=0, b=0, t=40), | |
annotations=[ | |
dict( | |
showarrow=False, | |
text=f"Interactive {layout_type} visualization of GraphML data", | |
xref="paper", | |
yref="paper", | |
x=0, | |
y=0 | |
) | |
], | |
autosize=True | |
) | |
fig.update_layout(autosize=True) | |
fig.update_layout(height=600) # Set a fixed height | |
return fig, f"Graph visualization generated successfully. Using file: {graph_path}" | |
except Exception as e: | |
return go.Figure(), f"Error visualizing graph: {str(e)}" | |
def update_logs(): | |
logs = [] | |
while not log_queue.empty(): | |
logs.append(log_queue.get()) | |
return "\n".join(logs) | |
def fetch_models(base_url, api_key, service_type): | |
try: | |
if service_type.lower() == "ollama": | |
response = requests.get(f"{base_url}/tags", timeout=10) | |
else: # OpenAI Compatible | |
headers = { | |
"Authorization": f"Bearer {api_key}", | |
"Content-Type": "application/json" | |
} | |
response = requests.get(f"{base_url}/models", headers=headers, timeout=10) | |
logging.info(f"Raw API response: {response.text}") | |
if response.status_code == 200: | |
data = response.json() | |
if service_type.lower() == "ollama": | |
models = [model.get('name', '') for model in data.get('models', data) if isinstance(model, dict)] | |
else: # OpenAI Compatible | |
models = [model.get('id', '') for model in data.get('data', []) if isinstance(model, dict)] | |
models = [model for model in models if model] # Remove empty strings | |
if not models: | |
logging.warning(f"No models found in {service_type} API response") | |
return ["No models available"] | |
logging.info(f"Successfully fetched {service_type} models: {models}") | |
return models | |
else: | |
logging.error(f"Error fetching {service_type} models. Status code: {response.status_code}, Response: {response.text}") | |
return ["Error fetching models"] | |
except requests.RequestException as e: | |
logging.error(f"Exception while fetching {service_type} models: {str(e)}") | |
return ["Error: Connection failed"] | |
except Exception as e: | |
logging.error(f"Unexpected error in fetch_models: {str(e)}") | |
return ["Error: Unexpected issue"] | |
def update_model_choices(base_url, api_key, service_type, settings_key): | |
models = fetch_models(base_url, api_key, service_type) | |
if not models: | |
logging.warning(f"No models fetched for {service_type}.") | |
# Get the current model from settings: the top-level 'llm' section stores the model
# directly, while the 'embeddings' section nests it under an inner 'llm' key
if settings_key == 'llm':
    current_model = settings.get('llm', {}).get('model')
else:
    current_model = settings.get(settings_key, {}).get('llm', {}).get('model')
# If the current model is not in the list, add it | |
if current_model and current_model not in models: | |
models.append(current_model) | |
return gr.update(choices=models, value=current_model if current_model in models else (models[0] if models else None)) | |
def update_llm_model_choices(base_url, api_key, service_type): | |
return update_model_choices(base_url, api_key, service_type, 'llm') | |
def update_embeddings_model_choices(base_url, api_key, service_type): | |
return update_model_choices(base_url, api_key, service_type, 'embeddings') | |
def update_llm_settings(llm_model, embeddings_model, context_window, system_message, temperature, max_tokens, | |
llm_api_base, llm_api_key, | |
embeddings_api_base, embeddings_api_key, embeddings_service_type): | |
try: | |
# Update settings.yaml | |
settings = load_settings() | |
settings['llm'].update({ | |
"type": "openai", # Always set to "openai" since we removed the radio button | |
"model": llm_model, | |
"api_base": llm_api_base, | |
"api_key": "${GRAPHRAG_API_KEY}", | |
"temperature": temperature, | |
"max_tokens": max_tokens, | |
"provider": "openai_chat" # Always set to "openai_chat" | |
}) | |
settings['embeddings']['llm'].update({ | |
"type": "openai_embedding", # Always use OpenAIEmbeddingsLLM | |
"model": embeddings_model, | |
"api_base": embeddings_api_base, | |
"api_key": "${GRAPHRAG_API_KEY}", | |
"provider": embeddings_service_type | |
}) | |
with open("indexing/settings.yaml", 'w') as f: | |
yaml.dump(settings, f, default_flow_style=False) | |
# Update .env file | |
update_env_file("LLM_API_BASE", llm_api_base) | |
update_env_file("LLM_API_KEY", llm_api_key) | |
update_env_file("LLM_MODEL", llm_model) | |
update_env_file("EMBEDDINGS_API_BASE", embeddings_api_base) | |
update_env_file("EMBEDDINGS_API_KEY", embeddings_api_key) | |
update_env_file("EMBEDDINGS_MODEL", embeddings_model) | |
update_env_file("CONTEXT_WINDOW", str(context_window)) | |
update_env_file("SYSTEM_MESSAGE", system_message) | |
update_env_file("TEMPERATURE", str(temperature)) | |
update_env_file("MAX_TOKENS", str(max_tokens)) | |
update_env_file("LLM_SERVICE_TYPE", "openai_chat") | |
update_env_file("EMBEDDINGS_SERVICE_TYPE", embeddings_service_type) | |
# Reload environment variables | |
load_dotenv(override=True) | |
return "LLM and embeddings settings updated successfully in both settings.yaml and .env files." | |
except Exception as e: | |
return f"Error updating LLM and embeddings settings: {str(e)}" | |
def update_env_file(key, value): | |
env_path = 'indexing/.env' | |
with open(env_path, 'r') as file: | |
lines = file.readlines() | |
updated = False | |
for i, line in enumerate(lines): | |
if line.startswith(f"{key}="): | |
lines[i] = f"{key}={value}\n" | |
updated = True | |
break | |
if not updated: | |
lines.append(f"{key}={value}\n") | |
with open(env_path, 'w') as file: | |
file.writelines(lines) | |
custom_css = """ | |
html, body { | |
margin: 0; | |
padding: 0; | |
height: 100vh; | |
overflow: hidden; | |
} | |
.gradio-container { | |
margin: 0 !important; | |
padding: 0 !important; | |
width: 100vw !important; | |
max-width: 100vw !important; | |
height: 100vh !important; | |
max-height: 100vh !important; | |
overflow: auto; | |
display: flex; | |
flex-direction: column; | |
} | |
#main-container { | |
flex: 1; | |
display: flex; | |
overflow: hidden; | |
} | |
#left-column, #right-column { | |
height: 100%; | |
overflow-y: auto; | |
padding: 10px; | |
} | |
#left-column { | |
flex: 1; | |
} | |
#right-column { | |
flex: 2; | |
display: flex; | |
flex-direction: column; | |
} | |
#chat-container { | |
flex: 0 0 auto; /* Don't allow this to grow */ | |
height: 100%; | |
display: flex; | |
flex-direction: column; | |
overflow: hidden; | |
border: 1px solid var(--color-accent); | |
border-radius: 8px; | |
padding: 10px; | |
overflow-y: auto; | |
} | |
#chatbot { | |
overflow-y: hidden; | |
height: 100%; | |
} | |
#chat-input-row { | |
margin-top: 10px; | |
} | |
#visualization-plot { | |
width: 100%; | |
aspect-ratio: 1 / 1; | |
max-height: 600px; /* Adjust this value as needed */ | |
} | |
#vis-controls-row { | |
display: flex; | |
justify-content: space-between; | |
align-items: center; | |
margin-top: 10px; | |
} | |
#vis-controls-row > * { | |
flex: 1; | |
margin: 0 5px; | |
} | |
#vis-status { | |
margin-top: 10px; | |
} | |
/* Chat input styling */ | |
#chat-input-row { | |
display: flex; | |
flex-direction: column; | |
} | |
#chat-input-row > div { | |
width: 100% !important; | |
} | |
#chat-input-row input[type="text"] { | |
width: 100% !important; | |
} | |
/* Adjust padding for all containers */ | |
.gr-box, .gr-form, .gr-panel { | |
padding: 10px !important; | |
} | |
/* Ensure all textboxes and textareas have full height */ | |
.gr-textbox, .gr-textarea { | |
height: auto !important; | |
min-height: 100px !important; | |
} | |
/* Ensure all dropdowns have full width */ | |
.gr-dropdown { | |
width: 100% !important; | |
} | |
:root { | |
--color-background: #2C3639; | |
--color-foreground: #3F4E4F; | |
--color-accent: #A27B5C; | |
--color-text: #DCD7C9; | |
} | |
body, .gradio-container { | |
background-color: var(--color-background); | |
color: var(--color-text); | |
} | |
.gr-button { | |
background-color: var(--color-accent); | |
color: var(--color-text); | |
} | |
.gr-input, .gr-textarea, .gr-dropdown { | |
background-color: var(--color-foreground); | |
color: var(--color-text); | |
border: 1px solid var(--color-accent); | |
} | |
.gr-panel { | |
background-color: var(--color-foreground); | |
border: 1px solid var(--color-accent); | |
} | |
.gr-box { | |
border-radius: 8px; | |
margin-bottom: 10px; | |
background-color: var(--color-foreground); | |
} | |
.gr-padded { | |
padding: 10px; | |
} | |
.gr-form { | |
background-color: var(--color-foreground); | |
} | |
.gr-input-label, .gr-radio-label { | |
color: var(--color-text); | |
} | |
.gr-checkbox-label { | |
color: var(--color-text); | |
} | |
.gr-markdown { | |
color: var(--color-text); | |
} | |
.gr-accordion { | |
background-color: var(--color-foreground); | |
border: 1px solid var(--color-accent); | |
} | |
.gr-accordion-header { | |
background-color: var(--color-accent); | |
color: var(--color-text); | |
} | |
#visualization-container { | |
display: flex; | |
flex-direction: column; | |
border: 2px solid var(--color-accent); | |
border-radius: 8px; | |
margin-top: 20px; | |
padding: 10px; | |
background-color: var(--color-foreground); | |
height: calc(100vh - 300px); /* Adjust this value as needed */ | |
} | |
#visualization-plot { | |
width: 100%; | |
height: 100%; | |
} | |
#log-container { | |
background-color: var(--color-foreground); | |
border: 1px solid var(--color-accent); | |
border-radius: 8px; | |
padding: 10px; | |
margin-top: 20px; | |
max-height: none;
overflow-y: auto; | |
} | |
.setting-accordion .label-wrap { | |
cursor: pointer; | |
} | |
.setting-accordion .icon { | |
transition: transform 0.3s ease; | |
} | |
.setting-accordion[open] .icon { | |
transform: rotate(90deg); | |
} | |
.gr-form.gr-box { | |
border: none !important; | |
background: none !important; | |
} | |
.model-params { | |
border-top: 1px solid var(--color-accent); | |
margin-top: 10px; | |
padding-top: 10px; | |
} | |
""" | |
def list_output_files(root_dir): | |
output_dir = os.path.join(root_dir, "output") | |
files = [] | |
for root, _, filenames in os.walk(output_dir): | |
for filename in filenames: | |
files.append(os.path.join(root, filename)) | |
return files | |
def update_file_list(): | |
files = list_input_files() | |
return gr.update(choices=[f["path"] for f in files]) | |
def update_file_content(file_path): | |
if not file_path: | |
return "" | |
try: | |
with open(file_path, 'r', encoding='utf-8') as file: | |
content = file.read() | |
return content | |
except Exception as e: | |
logging.error(f"Error reading file: {str(e)}") | |
return f"Error reading file: {str(e)}" | |
def list_output_folders(root_dir): | |
output_dir = os.path.join(root_dir, "output") | |
folders = [f for f in os.listdir(output_dir) if os.path.isdir(os.path.join(output_dir, f))] | |
return sorted(folders, reverse=True) | |
def list_folder_contents(folder_path): | |
contents = [] | |
for item in os.listdir(folder_path): | |
item_path = os.path.join(folder_path, item) | |
if os.path.isdir(item_path): | |
contents.append(f"[DIR] {item}") | |
else: | |
_, ext = os.path.splitext(item) | |
contents.append(f"[{ext[1:].upper()}] {item}") | |
return contents | |
def update_output_folder_list(): | |
root_dir = "./" | |
folders = list_output_folders(root_dir) | |
return gr.update(choices=folders, value=folders[0] if folders else None) | |
def update_folder_content_list(folder_name): | |
root_dir = "./" | |
if not folder_name: | |
return gr.update(choices=[]) | |
contents = list_folder_contents(os.path.join(root_dir, "output", folder_name, "artifacts")) | |
return gr.update(choices=contents) | |
def handle_content_selection(folder_name, selected_item): | |
root_dir = "./" | |
if isinstance(selected_item, list) and selected_item: | |
selected_item = selected_item[0] # Take the first item if it's a list | |
if isinstance(selected_item, str) and selected_item.startswith("[DIR]"): | |
dir_name = selected_item[6:]  # Remove "[DIR] " prefix
sub_contents = list_folder_contents(os.path.join(root_dir, "output", folder_name, "artifacts", dir_name))
return gr.update(choices=sub_contents), "", "" | |
elif isinstance(selected_item, str): | |
file_name = selected_item.split("] ")[1] if "]" in selected_item else selected_item # Remove file type prefix if present | |
file_path = os.path.join(root_dir, "output", folder_name, "artifacts", file_name) | |
file_size = os.path.getsize(file_path) | |
file_type = os.path.splitext(file_name)[1] | |
file_info = f"File: {file_name}\nSize: {file_size} bytes\nType: {file_type}" | |
content = read_file_content(file_path) | |
return gr.update(), file_info, content | |
else: | |
return gr.update(), "", "" | |
def initialize_selected_folder(folder_name): | |
root_dir = "./" | |
if not folder_name: | |
return "Please select a folder first.", gr.update(choices=[]) | |
folder_path = os.path.join(root_dir, "output", folder_name, "artifacts") | |
if not os.path.exists(folder_path): | |
return f"Artifacts folder not found in '{folder_name}'.", gr.update(choices=[]) | |
contents = list_folder_contents(folder_path) | |
return f"Folder '{folder_name}/artifacts' initialized with {len(contents)} items.", gr.update(choices=contents) | |
settings = load_settings() | |
default_model = settings['llm']['model'] | |
cli_args = gr.State({}) | |
stop_indexing = threading.Event() | |
indexing_thread = None | |
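# Indexing runs in a background thread so the Gradio UI stays responsive;
# stop_indexing is the event run_indexing polls to terminate the subprocess early.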
def start_indexing(*args): | |
global indexing_thread, stop_indexing | |
stop_indexing = threading.Event() # Reset the stop_indexing event | |
indexing_thread = threading.Thread(target=run_indexing, args=args) | |
indexing_thread.start() | |
return gr.update(interactive=False), gr.update(interactive=True), gr.update(interactive=False) | |
def stop_indexing_process(): | |
global indexing_thread | |
logging.info("Stop indexing requested") | |
stop_indexing.set() | |
if indexing_thread and indexing_thread.is_alive(): | |
logging.info("Waiting for indexing thread to finish") | |
indexing_thread.join(timeout=10) | |
logging.info("Indexing thread finished" if not indexing_thread.is_alive() else "Indexing thread did not finish within timeout") | |
indexing_thread = None # Reset the thread | |
return gr.update(interactive=True), gr.update(interactive=False), gr.update(interactive=True) | |
def refresh_indexing(): | |
global indexing_thread, stop_indexing | |
if indexing_thread and indexing_thread.is_alive(): | |
logging.info("Cannot refresh: Indexing is still running") | |
return gr.update(interactive=False), gr.update(interactive=True), gr.update(interactive=False), "Cannot refresh: Indexing is still running" | |
else: | |
stop_indexing = threading.Event() # Reset the stop_indexing event | |
indexing_thread = None # Reset the thread | |
return gr.update(interactive=True), gr.update(interactive=False), gr.update(interactive=True), "Indexing process refreshed. You can start indexing again." | |
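# Generator that launches python -m graphrag.index as a subprocess and streams its
# output back to the UI, yielding log text, progress, button states and an iteration count.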
def run_indexing(root_dir, config_file, verbose, nocache, resume, reporter, emit_formats, custom_args): | |
cmd = ["python", "-m", "graphrag.index", "--root", "./indexing"] | |
# Add custom CLI arguments | |
if custom_args: | |
cmd.extend(custom_args.split()) | |
logging.info(f"Executing command: {' '.join(cmd)}") | |
process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1, encoding='utf-8', universal_newlines=True) | |
output = [] | |
progress_value = 0 | |
iterations_completed = 0 | |
while True: | |
if stop_indexing.is_set(): | |
process.terminate()
try:
    process.wait(timeout=5)
except subprocess.TimeoutExpired:
    # wait() raises if the process ignores terminate(); force-kill it
    process.kill()
return ("\n".join(output + ["Indexing stopped by user."]), | |
"Indexing stopped.", | |
100, | |
gr.update(interactive=True), | |
gr.update(interactive=False), | |
gr.update(interactive=True), | |
str(iterations_completed)) | |
try: | |
line = process.stdout.readline() | |
if not line and process.poll() is not None: | |
break | |
if line: | |
line = line.strip() | |
output.append(line) | |
if "Processing file" in line: | |
progress_value += 1 | |
iterations_completed += 1 | |
elif "Indexing completed" in line: | |
progress_value = 100 | |
elif "ERROR" in line: | |
line = f"🚨 ERROR: {line}" | |
yield ("\n".join(output), | |
line, | |
progress_value, | |
gr.update(interactive=False), | |
gr.update(interactive=True), | |
gr.update(interactive=False), | |
str(iterations_completed)) | |
except Exception as e: | |
logging.error(f"Error during indexing: {str(e)}") | |
return ("\n".join(output + [f"Error: {str(e)}"]), | |
"Error occurred during indexing.", | |
100, | |
gr.update(interactive=True), | |
gr.update(interactive=False), | |
gr.update(interactive=True), | |
str(iterations_completed)) | |
if process.returncode != 0 and not stop_indexing.is_set(): | |
final_output = "\n".join(output + [f"Error: Process exited with return code {process.returncode}"]) | |
final_progress = "Indexing failed. Check output for details." | |
else: | |
final_output = "\n".join(output) | |
final_progress = "Indexing completed successfully!" | |
return (final_output, | |
final_progress, | |
100, | |
gr.update(interactive=True), | |
gr.update(interactive=False), | |
gr.update(interactive=True), | |
str(iterations_completed)) | |
global_vector_store_wrapper = None | |
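# Build the full Gradio Blocks UI: data management, indexing, output browsing and
# visualisation, LLM/YAML settings on the left; chat, query options and the graph plot on the right.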
def create_gradio_interface(): | |
global global_vector_store_wrapper | |
llm_models, embeddings_models, llm_service_type, embeddings_service_type, llm_api_base, embeddings_api_base, text_embedder = initialize_models() | |
settings = load_settings() | |
log_output = gr.TextArea(label="Logs", elem_id="log-output", interactive=False, visible=False) | |
with gr.Blocks(css=custom_css, theme=gr.themes.Base()) as demo: | |
gr.Markdown("# GraphRAG Local UI", elem_id="title") | |
with gr.Row(elem_id="main-container"): | |
with gr.Column(scale=1, elem_id="left-column"): | |
with gr.Tabs(): | |
with gr.TabItem("Data Management"): | |
with gr.Accordion("File Upload (.txt)", open=True): | |
file_upload = gr.File(label="Upload .txt File", file_types=[".txt"]) | |
upload_btn = gr.Button("Upload File", variant="primary") | |
upload_output = gr.Textbox(label="Upload Status", visible=False) | |
with gr.Accordion("File Management", open=True): | |
file_list = gr.Dropdown(label="Select File", choices=[], interactive=True) | |
refresh_btn = gr.Button("Refresh File List", variant="secondary") | |
file_content = gr.TextArea(label="File Content", lines=10) | |
with gr.Row(): | |
delete_btn = gr.Button("Delete Selected File", variant="stop") | |
save_btn = gr.Button("Save Changes", variant="primary") | |
operation_status = gr.Textbox(label="Operation Status", visible=False) | |
with gr.TabItem("Indexing"): | |
root_dir = gr.Textbox(label="Root Directory", value="./") | |
config_file = gr.File(label="Config File (optional)") | |
with gr.Row(): | |
verbose = gr.Checkbox(label="Verbose", value=True) | |
nocache = gr.Checkbox(label="No Cache", value=True) | |
with gr.Row(): | |
resume = gr.Textbox(label="Resume Timestamp (optional)") | |
reporter = gr.Dropdown(label="Reporter", choices=["rich", "print", "none"], value=None) | |
with gr.Row(): | |
emit_formats = gr.CheckboxGroup(label="Emit Formats", choices=["json", "csv", "parquet"], value=None) | |
with gr.Row(): | |
run_index_button = gr.Button("Run Indexing") | |
stop_index_button = gr.Button("Stop Indexing", variant="stop") | |
refresh_index_button = gr.Button("Refresh Indexing", variant="secondary") | |
with gr.Accordion("Custom CLI Arguments", open=True): | |
custom_cli_args = gr.Textbox( | |
label="Custom CLI Arguments", | |
placeholder="--arg1 value1 --arg2 value2", | |
lines=3 | |
) | |
cli_guide = gr.Markdown( | |
textwrap.dedent(""" | |
### CLI Argument Key Guide: | |
- `--root <path>`: Set the root directory for the project | |
- `--config <path>`: Specify a custom configuration file | |
- `--verbose`: Enable verbose output | |
- `--nocache`: Disable caching | |
- `--resume <timestamp>`: Resume from a specific timestamp | |
- `--reporter <type>`: Set the reporter type (rich, print, none) | |
- `--emit <formats>`: Specify output formats (json, csv, parquet) | |
Example: `--verbose --nocache --emit json,csv` | |
""") | |
) | |
index_output = gr.Textbox(label="Indexing Output", lines=20, max_lines=30) | |
index_progress = gr.Textbox(label="Indexing Progress", lines=3) | |
iterations_completed = gr.Textbox(label="Iterations Completed", value="0") | |
refresh_status = gr.Textbox(label="Refresh Status", visible=True) | |
run_index_button.click( | |
fn=start_indexing, | |
inputs=[root_dir, config_file, verbose, nocache, resume, reporter, emit_formats, custom_cli_args], | |
outputs=[run_index_button, stop_index_button, refresh_index_button] | |
).then( | |
fn=run_indexing, | |
inputs=[root_dir, config_file, verbose, nocache, resume, reporter, emit_formats, custom_cli_args], | |
outputs=[index_output, index_progress, run_index_button, stop_index_button, refresh_index_button, iterations_completed] | |
) | |
stop_index_button.click( | |
fn=stop_indexing_process, | |
outputs=[run_index_button, stop_index_button, refresh_index_button] | |
) | |
refresh_index_button.click( | |
fn=refresh_indexing, | |
outputs=[run_index_button, stop_index_button, refresh_index_button, refresh_status] | |
) | |
with gr.TabItem("Indexing Outputs/Visuals"): | |
output_folder_list = gr.Dropdown(label="Select Output Folder (Select GraphML File to Visualize)", choices=list_output_folders("./indexing"), interactive=True) | |
refresh_folder_btn = gr.Button("Refresh Folder List", variant="secondary") | |
initialize_folder_btn = gr.Button("Initialize Selected Folder", variant="primary") | |
folder_content_list = gr.Dropdown(label="Select File or Directory", choices=[], interactive=True) | |
file_info = gr.Textbox(label="File Information", interactive=False) | |
output_content = gr.TextArea(label="File Content", lines=20, interactive=False) | |
initialization_status = gr.Textbox(label="Initialization Status") | |
with gr.TabItem("LLM Settings"): | |
llm_base_url = gr.Textbox(label="LLM API Base URL", value=os.getenv("LLM_API_BASE")) | |
llm_api_key = gr.Textbox(label="LLM API Key", value=os.getenv("LLM_API_KEY"), type="password") | |
llm_service_type = gr.Radio( | |
label="LLM Service Type", | |
choices=["openai", "ollama"], | |
value="openai", | |
visible=False # Hide this if you want to always use OpenAI | |
) | |
llm_model_dropdown = gr.Dropdown( | |
label="LLM Model", | |
choices=[], # Start with an empty list | |
value=settings['llm'].get('model'), | |
allow_custom_value=True | |
) | |
refresh_llm_models_btn = gr.Button("Refresh LLM Models", variant="secondary") | |
embeddings_base_url = gr.Textbox(label="Embeddings API Base URL", value=os.getenv("EMBEDDINGS_API_BASE")) | |
embeddings_api_key = gr.Textbox(label="Embeddings API Key", value=os.getenv("EMBEDDINGS_API_KEY"), type="password") | |
embeddings_service_type = gr.Radio( | |
label="Embeddings Service Type", | |
choices=["openai", "ollama"], | |
value=settings.get('embeddings', {}).get('llm', {}).get('type', 'openai'), | |
visible=False, | |
) | |
embeddings_model_dropdown = gr.Dropdown( | |
label="Embeddings Model", | |
choices=[], | |
value=settings.get('embeddings', {}).get('llm', {}).get('model'), | |
allow_custom_value=True | |
) | |
refresh_embeddings_models_btn = gr.Button("Refresh Embedding Models", variant="secondary") | |
system_message = gr.Textbox( | |
lines=5, | |
label="System Message", | |
value=os.getenv("SYSTEM_MESSAGE", "You are a helpful AI assistant.") | |
) | |
context_window = gr.Slider( | |
label="Context Window", | |
minimum=512, | |
maximum=32768, | |
step=512, | |
value=int(os.getenv("CONTEXT_WINDOW", 4096)) | |
) | |
temperature = gr.Slider( | |
label="Temperature", | |
minimum=0.0, | |
maximum=2.0, | |
step=0.1, | |
value=float(settings['llm'].get('temperature', 0.5))
) | |
max_tokens = gr.Slider( | |
label="Max Tokens", | |
minimum=1, | |
maximum=8192, | |
step=1, | |
value=int(settings['llm'].get('max_tokens', 1024))
) | |
update_settings_btn = gr.Button("Update LLM Settings", variant="primary") | |
llm_settings_status = gr.Textbox(label="Status", interactive=False) | |
llm_base_url.change( | |
fn=update_model_choices, | |
inputs=[llm_base_url, llm_api_key, llm_service_type, gr.Textbox(value='llm', visible=False)], | |
outputs=llm_model_dropdown | |
) | |
# Update Embeddings model choices when service type or base URL changes | |
embeddings_service_type.change( | |
fn=update_embeddings_model_choices, | |
inputs=[embeddings_base_url, embeddings_api_key, embeddings_service_type], | |
outputs=embeddings_model_dropdown | |
) | |
embeddings_base_url.change( | |
fn=update_model_choices, | |
inputs=[embeddings_base_url, embeddings_api_key, embeddings_service_type, gr.Textbox(value='embeddings', visible=False)], | |
outputs=embeddings_model_dropdown | |
) | |
update_settings_btn.click( | |
fn=update_llm_settings, | |
inputs=[ | |
llm_model_dropdown, | |
embeddings_model_dropdown, | |
context_window, | |
system_message, | |
temperature, | |
max_tokens, | |
llm_base_url, | |
llm_api_key, | |
embeddings_base_url, | |
embeddings_api_key, | |
embeddings_service_type | |
], | |
outputs=[llm_settings_status] | |
) | |
refresh_llm_models_btn.click( | |
fn=update_model_choices, | |
inputs=[llm_base_url, llm_api_key, llm_service_type, gr.Textbox(value='llm', visible=False)], | |
outputs=[llm_model_dropdown] | |
).then( | |
fn=update_logs, | |
outputs=[log_output] | |
) | |
refresh_embeddings_models_btn.click( | |
fn=update_model_choices, | |
inputs=[embeddings_base_url, embeddings_api_key, embeddings_service_type, gr.Textbox(value='embeddings', visible=False)], | |
outputs=[embeddings_model_dropdown] | |
).then( | |
fn=update_logs, | |
outputs=[log_output] | |
) | |
with gr.TabItem("YAML Settings"): | |
settings = load_settings() | |
with gr.Group(): | |
for key, value in settings.items(): | |
if key != 'llm': | |
create_setting_component(key, value) | |
with gr.Group(elem_id="log-container"): | |
gr.Markdown("### Logs") | |
log_output = gr.TextArea(label="Logs", elem_id="log-output", interactive=False) | |
with gr.Column(scale=2, elem_id="right-column"): | |
with gr.Group(elem_id="chat-container"): | |
chatbot = gr.Chatbot(label="Chat History", elem_id="chatbot") | |
with gr.Row(elem_id="chat-input-row"): | |
with gr.Column(scale=1): | |
query_input = gr.Textbox( | |
label="Input", | |
placeholder="Enter your query here...", | |
elem_id="query-input" | |
) | |
query_btn = gr.Button("Send Query", variant="primary") | |
with gr.Accordion("Query Parameters", open=True): | |
query_type = gr.Radio( | |
["global", "local", "direct"], | |
label="Query Type", | |
value="global", | |
info="Global: community-based search, Local: entity-based search, Direct: LLM chat" | |
) | |
preset_dropdown = gr.Dropdown( | |
label="Preset Query Options", | |
choices=[ | |
"Default Global Search", | |
"Default Local Search", | |
"Detailed Global Analysis", | |
"Detailed Local Analysis", | |
"Quick Global Summary", | |
"Quick Local Summary", | |
"Global Bullet Points", | |
"Local Bullet Points", | |
"Comprehensive Global Report", | |
"Comprehensive Local Report", | |
"High-Level Global Overview", | |
"High-Level Local Overview", | |
"Focused Global Insight", | |
"Focused Local Insight", | |
"Custom Query" | |
], | |
value="Default Global Search", | |
info="Select a preset or choose 'Custom Query' for manual configuration" | |
) | |
selected_folder = gr.Dropdown( | |
label="Select Index Folder to Chat With", | |
choices=list_output_folders("./indexing"), | |
value=None, | |
interactive=True | |
) | |
refresh_folder_btn = gr.Button("Refresh Folders", variant="secondary") | |
clear_chat_btn = gr.Button("Clear Chat", variant="secondary") | |
with gr.Group(visible=False) as custom_options: | |
community_level = gr.Slider( | |
label="Community Level", | |
minimum=1, | |
maximum=10, | |
value=2, | |
step=1, | |
info="Higher values use reports on smaller communities" | |
) | |
response_type = gr.Dropdown( | |
label="Response Type", | |
choices=[ | |
"Multiple Paragraphs", | |
"Single Paragraph", | |
"Single Sentence", | |
"List of 3-7 Points", | |
"Single Page", | |
"Multi-Page Report" | |
], | |
value="Multiple Paragraphs", | |
info="Specify the desired format of the response" | |
) | |
custom_cli_args = gr.Textbox( | |
label="Custom CLI Arguments", | |
placeholder="--arg1 value1 --arg2 value2", | |
info="Additional CLI arguments for advanced users" | |
) | |
def update_custom_options(preset): | |
if preset == "Custom Query": | |
return gr.update(visible=True) | |
else: | |
return gr.update(visible=False) | |
preset_dropdown.change(fn=update_custom_options, inputs=[preset_dropdown], outputs=[custom_options]) | |
with gr.Group(elem_id="visualization-container"): | |
vis_output = gr.Plot(label="Graph Visualization", elem_id="visualization-plot") | |
with gr.Row(elem_id="vis-controls-row"): | |
vis_btn = gr.Button("Visualize Graph", variant="secondary") | |
# Add new controls for customization | |
with gr.Accordion("Visualization Settings", open=False): | |
layout_type = gr.Dropdown(["3D Spring", "2D Spring", "Circular"], label="Layout Type", value="3D Spring") | |
node_size = gr.Slider(1, 20, 7, label="Node Size", step=1) | |
edge_width = gr.Slider(0.1, 5, 0.5, label="Edge Width", step=0.1) | |
node_color_attribute = gr.Dropdown(["Degree", "Random"], label="Node Color Attribute", value="Degree") | |
color_scheme = gr.Dropdown(["Viridis", "Plasma", "Inferno", "Magma", "Cividis"], label="Color Scheme", value="Viridis") | |
show_labels = gr.Checkbox(label="Show Node Labels", value=True) | |
label_size = gr.Slider(5, 20, 10, label="Label Size", step=1) | |
# Event handlers | |
upload_btn.click(fn=upload_file, inputs=[file_upload], outputs=[upload_output, file_list, log_output]) | |
refresh_btn.click(fn=update_file_list, outputs=[file_list]).then( | |
fn=update_logs, | |
outputs=[log_output] | |
) | |
file_list.change(fn=update_file_content, inputs=[file_list], outputs=[file_content]).then( | |
fn=update_logs, | |
outputs=[log_output] | |
) | |
delete_btn.click(fn=delete_file, inputs=[file_list], outputs=[operation_status, file_list, log_output]) | |
save_btn.click(fn=save_file_content, inputs=[file_list, file_content], outputs=[operation_status, log_output]) | |
refresh_folder_btn.click( | |
fn=lambda: gr.update(choices=list_output_folders("./indexing")), | |
outputs=[selected_folder] | |
) | |
clear_chat_btn.click( | |
fn=lambda: ([], ""), | |
outputs=[chatbot, query_input] | |
) | |
refresh_folder_btn.click( | |
fn=update_output_folder_list, | |
outputs=[output_folder_list] | |
).then( | |
fn=update_logs, | |
outputs=[log_output] | |
) | |
output_folder_list.change( | |
fn=update_folder_content_list, | |
inputs=[output_folder_list], | |
outputs=[folder_content_list] | |
).then( | |
fn=update_logs, | |
outputs=[log_output] | |
) | |
folder_content_list.change( | |
fn=handle_content_selection, | |
inputs=[output_folder_list, folder_content_list], | |
outputs=[folder_content_list, file_info, output_content] | |
).then( | |
fn=update_logs, | |
outputs=[log_output] | |
) | |
initialize_folder_btn.click( | |
fn=initialize_selected_folder, | |
inputs=[output_folder_list], | |
outputs=[initialization_status, folder_content_list] | |
).then( | |
fn=update_logs, | |
outputs=[log_output] | |
) | |
vis_btn.click( | |
fn=update_visualization, | |
inputs=[ | |
output_folder_list, | |
folder_content_list, | |
layout_type, | |
node_size, | |
edge_width, | |
node_color_attribute, | |
color_scheme, | |
show_labels, | |
label_size | |
], | |
outputs=[vis_output, gr.Textbox(label="Visualization Status")] | |
) | |
query_btn.click( | |
fn=send_message, | |
inputs=[ | |
query_type, | |
query_input, | |
chatbot, | |
system_message, | |
temperature, | |
max_tokens, | |
preset_dropdown, | |
community_level, | |
response_type, | |
custom_cli_args, | |
selected_folder | |
], | |
outputs=[chatbot, query_input, log_output] | |
) | |
query_input.submit( | |
fn=send_message, | |
inputs=[ | |
query_type, | |
query_input, | |
chatbot, | |
system_message, | |
temperature, | |
max_tokens, | |
preset_dropdown, | |
community_level, | |
response_type, | |
custom_cli_args, | |
selected_folder | |
], | |
outputs=[chatbot, query_input, log_output] | |
) | |
refresh_llm_models_btn.click( | |
fn=update_model_choices, | |
inputs=[llm_base_url, llm_api_key, llm_service_type, gr.Textbox(value='llm', visible=False)], | |
outputs=[llm_model_dropdown] | |
) | |
# Update Embeddings model choices | |
refresh_embeddings_models_btn.click( | |
fn=update_model_choices, | |
inputs=[embeddings_base_url, embeddings_api_key, embeddings_service_type, gr.Textbox(value='embeddings', visible=False)], | |
outputs=[embeddings_model_dropdown] | |
) | |
# Add this JavaScript to enable Shift+Enter functionality | |
demo.load(js=""" | |
function addShiftEnterListener() { | |
const queryInput = document.getElementById('query-input'); | |
if (queryInput) { | |
queryInput.addEventListener('keydown', function(event) { | |
if (event.key === 'Enter' && event.shiftKey) { | |
event.preventDefault(); | |
const submitButton = queryInput.closest('.gradio-container').querySelector('button.primary'); | |
if (submitButton) { | |
submitButton.click(); | |
} | |
} | |
}); | |
} | |
} | |
document.addEventListener('DOMContentLoaded', addShiftEnterListener); | |
""") | |
return demo.queue() | |
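# Alternative entry point that also starts the companion API server (api_server.py)
# before launching Gradio. Note: this coroutine is not invoked by the __main__ block
# below, which launches the UI directly.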
async def main(): | |
api_port = 8088 | |
gradio_port = 7860 | |
print(f"Starting API server on port {api_port}") | |
start_api_server(api_port) | |
# Wait for the API server to start in a separate thread | |
threading.Thread(target=wait_for_api_server, args=(api_port,)).start() | |
# Create the Gradio app | |
demo = create_gradio_interface() | |
print(f"Starting Gradio app on port {gradio_port}") | |
# Launch the Gradio app | |
demo.launch(server_port=gradio_port, share=True) | |
demo = create_gradio_interface() | |
app = demo.app | |
if __name__ == "__main__": | |
initialize_data() | |
demo.launch(server_port=7860, share=True) | |