import gradio as gr
from gradio.helpers import Progress
import asyncio
import subprocess
import yaml
import os
import networkx as nx
import plotly.graph_objects as go
import numpy as np
import plotly.io as pio
import lancedb
import random
import io
import shutil
import logging
import queue
import threading
import time
from collections import deque
import re
import glob
from datetime import datetime
import json
import requests
import aiohttp
from openai import OpenAI
from openai import AsyncOpenAI
import pyarrow.parquet as pq
import pandas as pd
import sys
import colorsys
from dotenv import load_dotenv, set_key
import argparse
import socket
import tiktoken
from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey
from graphrag.query.indexer_adapters import (
read_indexer_covariates,
read_indexer_entities,
read_indexer_relationships,
read_indexer_reports,
read_indexer_text_units,
)
from graphrag.llm.openai import create_openai_chat_llm
from graphrag.llm.openai.factories import create_openai_embedding_llm
from graphrag.query.input.loaders.dfs import store_entity_semantic_embeddings
from graphrag.query.llm.oai.chat_openai import ChatOpenAI
from graphrag.llm.openai.openai_configuration import OpenAIConfiguration
from graphrag.llm.openai.openai_embeddings_llm import OpenAIEmbeddingsLLM
from graphrag.query.llm.oai.typing import OpenaiApiType
from graphrag.query.structured_search.local_search.mixed_context import LocalSearchMixedContext
from graphrag.query.structured_search.local_search.search import LocalSearch
from graphrag.query.structured_search.global_search.community_context import GlobalCommunityContext
from graphrag.query.structured_search.global_search.search import GlobalSearch
from graphrag.vector_stores.lancedb import LanceDBVectorStore
import textwrap
# Suppress warnings
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module="gradio_client.documentation")
load_dotenv('indexing/.env')
# Set default values for API-related environment variables; fall back to empty
# strings so os.environ.setdefault never receives None
os.environ.setdefault("LLM_API_BASE", os.getenv("LLM_API_BASE", ""))
os.environ.setdefault("LLM_API_KEY", os.getenv("LLM_API_KEY", ""))
os.environ.setdefault("LLM_MODEL", os.getenv("LLM_MODEL", ""))
os.environ.setdefault("EMBEDDINGS_API_BASE", os.getenv("EMBEDDINGS_API_BASE", ""))
os.environ.setdefault("EMBEDDINGS_API_KEY", os.getenv("EMBEDDINGS_API_KEY", ""))
os.environ.setdefault("EMBEDDINGS_MODEL", os.getenv("EMBEDDINGS_MODEL", ""))
# Add the project root to the Python path
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
sys.path.insert(0, project_root)
# Set up logging
log_queue = queue.Queue()
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
llm = None
text_embedder = None
class QueueHandler(logging.Handler):
def __init__(self, log_queue):
super().__init__()
self.log_queue = log_queue
def emit(self, record):
self.log_queue.put(self.format(record))
queue_handler = QueueHandler(log_queue)
logging.getLogger().addHandler(queue_handler)
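# Build the chat LLM and text-embedding clients from the environment configuration
# and fetch the available model lists for the UI dropdowns.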
def initialize_models():
global llm, text_embedder
llm_api_base = os.getenv("LLM_API_BASE")
llm_api_key = os.getenv("LLM_API_KEY")
embeddings_api_base = os.getenv("EMBEDDINGS_API_BASE")
embeddings_api_key = os.getenv("EMBEDDINGS_API_KEY")
llm_service_type = os.getenv("LLM_SERVICE_TYPE", "openai_chat").lower() # Provide a default and lower it
embeddings_service_type = os.getenv("EMBEDDINGS_SERVICE_TYPE", "openai").lower() # Provide a default and lower it
llm_model = os.getenv("LLM_MODEL")
embeddings_model = os.getenv("EMBEDDINGS_MODEL")
logging.info("Fetching models...")
models = fetch_models(llm_api_base, llm_api_key, llm_service_type)
# Use the same models list for both LLM and embeddings
llm_models = models
embeddings_models = models
# Initialize LLM
if llm_service_type == "openai_chat":
llm = ChatOpenAI(
api_key=llm_api_key,
api_base=f"{llm_api_base}/v1",
model=llm_model,
api_type=OpenaiApiType.OpenAI,
max_retries=20,
)
# Initialize OpenAI client for embeddings
openai_client = OpenAI(
api_key=embeddings_api_key or "dummy_key",
base_url=f"{embeddings_api_base}/v1"
)
# Initialize text embedder using OpenAIEmbeddingsLLM
text_embedder = OpenAIEmbeddingsLLM(
client=openai_client,
configuration={
"model": embeddings_model,
"api_type": "open_ai",
"api_base": embeddings_api_base,
"api_key": embeddings_api_key or None,
"provider": embeddings_service_type
}
)
return llm_models, embeddings_models, llm_service_type, embeddings_service_type, llm_api_base, embeddings_api_base, text_embedder
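# Locate the most recent timestamped run folder under ./indexing/output and verify
# that it contains an "artifacts" subfolder.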
def find_latest_output_folder():
root_dir = "./indexing/output"
folders = [f for f in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, f))]
if not folders:
raise ValueError("No output folders found")
# Sort folders by creation time, most recent first
sorted_folders = sorted(folders, key=lambda x: os.path.getctime(os.path.join(root_dir, x)), reverse=True)
latest_folder = None
timestamp = None
for folder in sorted_folders:
try:
# Try to parse the folder name as a timestamp
timestamp = datetime.strptime(folder, "%Y%m%d-%H%M%S")
latest_folder = folder
break
except ValueError:
# If the folder name is not a valid timestamp, skip it
continue
if latest_folder is None:
raise ValueError("No valid timestamp folders found")
latest_path = os.path.join(root_dir, latest_folder)
artifacts_path = os.path.join(latest_path, "artifacts")
if not os.path.exists(artifacts_path):
raise ValueError(f"Artifacts folder not found in {latest_path}")
return latest_path, latest_folder
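# Load the latest GraphRAG parquet artifacts (nodes, edges, text units, reports,
# covariates) into module-level DataFrames, falling back to empty frames on failure.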
def initialize_data():
global entity_df, relationship_df, text_unit_df, report_df, covariate_df
tables = {
"entity_df": "create_final_nodes",
"relationship_df": "create_final_edges",
"text_unit_df": "create_final_text_units",
"report_df": "create_final_reports",
"covariate_df": "create_final_covariates"
}
timestamp = None # Initialize timestamp to None
try:
latest_output_folder, timestamp = find_latest_output_folder()
artifacts_folder = os.path.join(latest_output_folder, "artifacts")
for df_name, file_prefix in tables.items():
file_pattern = os.path.join(artifacts_folder, f"{file_prefix}*.parquet")
matching_files = glob.glob(file_pattern)
if matching_files:
latest_file = max(matching_files, key=os.path.getctime)
df = pd.read_parquet(latest_file)
globals()[df_name] = df
logging.info(f"Successfully loaded {df_name} from {latest_file}")
else:
logging.warning(f"No matching file found for {df_name} in {artifacts_folder}. Initializing as an empty DataFrame.")
globals()[df_name] = pd.DataFrame()
except Exception as e:
logging.error(f"Error initializing data: {str(e)}")
for df_name in tables.keys():
globals()[df_name] = pd.DataFrame()
return timestamp
# Call initialize_data and store the timestamp
current_timestamp = initialize_data()
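# Scan ports sequentially from start_port and return the first one that can be bound.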
def find_available_port(start_port, max_attempts=100):
for port in range(start_port, start_port + max_attempts):
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
try:
s.bind(('', port))
return port
except OSError:
continue
raise IOError("No free ports found")
def start_api_server(port):
subprocess.Popen([sys.executable, "api_server.py", "--port", str(port)])
def wait_for_api_server(port):
max_retries = 30
for _ in range(max_retries):
try:
response = requests.get(f"http://localhost:{port}")
if response.status_code == 200:
print(f"API server is up and running on port {port}")
return
else:
print(f"Unexpected response from API server: {response.status_code}")
except requests.ConnectionError:
time.sleep(1)
print("Failed to connect to API server")
def load_settings():
try:
with open("indexing/settings.yaml", "r") as f:
return yaml.safe_load(f) or {}
except FileNotFoundError:
return {}
def update_setting(key, value):
settings = load_settings()
try:
settings[key] = json.loads(value)
except json.JSONDecodeError:
settings[key] = value
try:
with open("indexing/settings.yaml", "w") as f:
yaml.dump(settings, f, default_flow_style=False)
return f"Setting '{key}' updated successfully"
except Exception as e:
return f"Error updating setting '{key}': {str(e)}"
def create_setting_component(key, value):
with gr.Accordion(key, open=False):
if isinstance(value, (dict, list)):
value_str = json.dumps(value, indent=2)
lines = value_str.count('\n') + 1
else:
value_str = str(value)
lines = 1
text_area = gr.TextArea(value=value_str, label="Value", lines=lines, max_lines=20)
update_btn = gr.Button("Update", variant="primary")
status = gr.Textbox(label="Status", visible=False)
update_btn.click(
fn=update_setting,
inputs=[gr.Textbox(value=key, visible=False), text_area],
outputs=[status]
).then(
fn=lambda: gr.update(visible=True),
outputs=[status]
)
def get_openai_client():
    # OpenAI() does not accept a model argument; the model is chosen per request.
    return OpenAI(
        base_url=os.getenv("LLM_API_BASE"),
        api_key=os.getenv("LLM_API_KEY"),
    )
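# One-shot async chat completion; errors are returned as strings instead of raised.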
async def chat_with_openai(messages, model, temperature, max_tokens, api_base):
client = AsyncOpenAI(
base_url=api_base,
api_key=os.getenv("LLM_API_KEY")
)
try:
response = await client.chat.completions.create(
model=model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens
)
return response.choices[0].message.content
except Exception as e:
logging.error(f"Error in chat_with_openai: {str(e)}")
return f"An error occurred: {str(e)}"
return f"Error: {str(e)}"
def chat_with_llm(query, history, system_message, temperature, max_tokens, model, api_base):
try:
messages = [{"role": "system", "content": system_message}]
for item in history:
if isinstance(item, tuple) and len(item) == 2:
human, ai = item
messages.append({"role": "user", "content": human})
messages.append({"role": "assistant", "content": ai})
messages.append({"role": "user", "content": query})
logging.info(f"Sending chat request to {api_base} with model {model}")
client = OpenAI(base_url=api_base, api_key=os.getenv("LLM_API_KEY", "dummy-key"))
response = client.chat.completions.create(
model=model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens
)
return response.choices[0].message.content
except Exception as e:
logging.error(f"Error in chat_with_llm: {str(e)}")
logging.error(f"Attempted with model: {model}, api_base: {api_base}")
raise RuntimeError(f"Chat request failed: {str(e)}")
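# Run the GraphRAG query CLI as a subprocess and return its stdout; failures are
# logged and re-raised.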
def run_graphrag_query(cli_args):
try:
command = ' '.join(cli_args)
logging.info(f"Executing command: {command}")
result = subprocess.run(cli_args, capture_output=True, text=True, check=True)
return result.stdout.strip()
except subprocess.CalledProcessError as e:
logging.error(f"Error running GraphRAG query: {e}")
logging.error(f"Command output (stdout): {e.stdout}")
logging.error(f"Command output (stderr): {e.stderr}")
raise RuntimeError(f"GraphRAG query failed: {e.stderr}")
def parse_query_response(response: str):
try:
# Split the response into metadata and content
parts = response.split("\n\n", 1)
if len(parts) < 2:
return response # Return original response if it doesn't contain metadata
metadata_str, content = parts
metadata = json.loads(metadata_str)
# Extract relevant information from metadata
query_type = metadata.get("query_type", "Unknown")
execution_time = metadata.get("execution_time", "N/A")
tokens_used = metadata.get("tokens_used", "N/A")
# Remove unwanted lines from the content
content_lines = content.split('\n')
filtered_content = '\n'.join([line for line in content_lines if not line.startswith("INFO:") and not line.startswith("creating llm client")])
# Format the parsed response
parsed_response = f"""
Query Type: {query_type}
Execution Time: {execution_time} seconds
Tokens Used: {tokens_used}
{filtered_content.strip()}
"""
return parsed_response
except Exception as e:
print(f"Error parsing query response: {str(e)}")
return response
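# Chat dispatcher: routes "global"/"local" queries through the GraphRAG CLI and
# "direct" queries straight to the LLM, then appends the result to the chat history.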
def send_message(query_type, query, history, system_message, temperature, max_tokens, preset, community_level, response_type, custom_cli_args, selected_folder):
try:
if query_type in ["global", "local"]:
cli_args = construct_cli_args(query_type, preset, community_level, response_type, custom_cli_args, query, selected_folder)
logging.info(f"Executing {query_type} search with command: {' '.join(cli_args)}")
result = run_graphrag_query(cli_args)
parsed_result = parse_query_response(result)
logging.info(f"Parsed query result: {parsed_result}")
else: # Direct chat
llm_model = os.getenv("LLM_MODEL")
api_base = os.getenv("LLM_API_BASE")
logging.info(f"Executing direct chat with model: {llm_model}")
try:
result = chat_with_llm(query, history, system_message, temperature, max_tokens, llm_model, api_base)
parsed_result = result # No parsing needed for direct chat
logging.info(f"Direct chat result: {parsed_result[:100]}...") # Log first 100 chars of result
except Exception as chat_error:
logging.error(f"Error in chat_with_llm: {str(chat_error)}")
raise RuntimeError(f"Direct chat failed: {str(chat_error)}")
history.append((query, parsed_result))
except Exception as e:
error_message = f"An error occurred: {str(e)}"
logging.error(error_message)
logging.exception("Exception details:")
history.append((query, error_message))
return history, gr.update(value=""), update_logs()
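# Translate the selected preset and query parameters into `python -m graphrag.query`
# command-line arguments.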
def construct_cli_args(query_type, preset, community_level, response_type, custom_cli_args, query, selected_folder):
if not selected_folder:
raise ValueError("No folder selected. Please select an output folder before querying.")
artifacts_folder = os.path.join("./indexing/output", selected_folder, "artifacts")
if not os.path.exists(artifacts_folder):
raise ValueError(f"Artifacts folder not found in {artifacts_folder}")
base_args = [
"python", "-m", "graphrag.query",
"--data", artifacts_folder,
"--method", query_type,
]
# Apply preset configurations
if preset.startswith("Default"):
base_args.extend(["--community_level", "2", "--response_type", "Multiple Paragraphs"])
elif preset.startswith("Detailed"):
base_args.extend(["--community_level", "4", "--response_type", "Multi-Page Report"])
elif preset.startswith("Quick"):
base_args.extend(["--community_level", "1", "--response_type", "Single Paragraph"])
elif preset.startswith("Bullet"):
base_args.extend(["--community_level", "2", "--response_type", "List of 3-7 Points"])
elif preset.startswith("Comprehensive"):
base_args.extend(["--community_level", "5", "--response_type", "Multi-Page Report"])
elif preset.startswith("High-Level"):
base_args.extend(["--community_level", "1", "--response_type", "Single Page"])
elif preset.startswith("Focused"):
base_args.extend(["--community_level", "3", "--response_type", "Multiple Paragraphs"])
elif preset == "Custom Query":
base_args.extend([
"--community_level", str(community_level),
"--response_type", f'"{response_type}"',
])
if custom_cli_args:
base_args.extend(custom_cli_args.split())
# Add the query at the end
base_args.append(query)
return base_args
def upload_file(file):
if file is not None:
input_dir = os.path.join("indexing", "input")
os.makedirs(input_dir, exist_ok=True)
# Get the original filename from the uploaded file
original_filename = file.name
# Create the destination path
destination_path = os.path.join(input_dir, os.path.basename(original_filename))
# Move the uploaded file to the destination path
shutil.move(file.name, destination_path)
logging.info(f"File uploaded and moved to: {destination_path}")
status = f"File uploaded: {os.path.basename(original_filename)}"
else:
status = "No file uploaded"
# Get the updated file list
updated_file_list = [f["path"] for f in list_input_files()]
return status, gr.update(choices=updated_file_list), update_logs()
def list_input_files():
input_dir = os.path.join("indexing", "input")
files = []
if os.path.exists(input_dir):
files = os.listdir(input_dir)
return [{"name": f, "path": os.path.join(input_dir, f)} for f in files]
def delete_file(file_path):
try:
os.remove(file_path)
logging.info(f"File deleted: {file_path}")
status = f"File deleted: {os.path.basename(file_path)}"
except Exception as e:
logging.error(f"Error deleting file: {str(e)}")
status = f"Error deleting file: {str(e)}"
# Get the updated file list
updated_file_list = [f["path"] for f in list_input_files()]
return status, gr.update(choices=updated_file_list), update_logs()
def read_file_content(file_path):
try:
if file_path.endswith('.parquet'):
df = pd.read_parquet(file_path)
# Get basic information about the DataFrame
info = f"Parquet File: {os.path.basename(file_path)}\n"
info += f"Rows: {len(df)}, Columns: {len(df.columns)}\n\n"
info += "Column Names:\n" + "\n".join(df.columns) + "\n\n"
# Display first few rows
info += "First 5 rows:\n"
info += df.head().to_string() + "\n\n"
# Display basic statistics
info += "Basic Statistics:\n"
info += df.describe().to_string()
return info
else:
with open(file_path, 'r', encoding='utf-8', errors='replace') as file:
content = file.read()
return content
except Exception as e:
logging.error(f"Error reading file: {str(e)}")
return f"Error reading file: {str(e)}"
def save_file_content(file_path, content):
try:
with open(file_path, 'w') as file:
file.write(content)
logging.info(f"File saved: {file_path}")
status = f"File saved: {os.path.basename(file_path)}"
except Exception as e:
logging.error(f"Error saving file: {str(e)}")
status = f"Error saving file: {str(e)}"
return status, update_logs()
def manage_data():
db = lancedb.connect("./indexing/lancedb")
tables = db.table_names()
table_info = ""
if tables:
table = db[tables[0]]
table_info = f"Table: {tables[0]}\nSchema: {table.schema}"
input_files = list_input_files()
return {
"database_info": f"Tables: {', '.join(tables)}\n\n{table_info}",
"input_files": input_files
}
def find_latest_graph_file(root_dir):
pattern = os.path.join(root_dir, "output", "*", "artifacts", "*.graphml")
graph_files = glob.glob(pattern)
if not graph_files:
# If no files found, try excluding .DS_Store
output_dir = os.path.join(root_dir, "output")
run_dirs = [d for d in os.listdir(output_dir) if os.path.isdir(os.path.join(output_dir, d)) and d != ".DS_Store"]
if run_dirs:
latest_run = max(run_dirs)
pattern = os.path.join(root_dir, "output", latest_run, "artifacts", "*.graphml")
graph_files = glob.glob(pattern)
if not graph_files:
return None
# Sort files by modification time, most recent first
latest_file = max(graph_files, key=os.path.getmtime)
return latest_file
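# Render the selected GraphML file as an interactive Plotly figure using a 3D/2D
# spring or circular layout with the requested node/edge styling.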
def update_visualization(folder_name, file_name, layout_type, node_size, edge_width, node_color_attribute, color_scheme, show_labels, label_size):
root_dir = "./indexing"
if not folder_name or not file_name:
return None, "Please select a folder and a GraphML file."
file_name = file_name.split("] ")[1] if "]" in file_name else file_name # Remove file type prefix
graph_path = os.path.join(root_dir, "output", folder_name, "artifacts", file_name)
if not graph_path.endswith('.graphml'):
return None, "Please select a GraphML file for visualization."
try:
# Load the GraphML file
graph = nx.read_graphml(graph_path)
# Create layout based on user selection
if layout_type == "3D Spring":
pos = nx.spring_layout(graph, dim=3, seed=42, k=0.5)
elif layout_type == "2D Spring":
pos = nx.spring_layout(graph, dim=2, seed=42, k=0.5)
else: # Circular
pos = nx.circular_layout(graph)
# Extract node positions
if layout_type == "3D Spring":
x_nodes = [pos[node][0] for node in graph.nodes()]
y_nodes = [pos[node][1] for node in graph.nodes()]
z_nodes = [pos[node][2] for node in graph.nodes()]
else:
x_nodes = [pos[node][0] for node in graph.nodes()]
y_nodes = [pos[node][1] for node in graph.nodes()]
z_nodes = [0] * len(graph.nodes()) # Set all z-coordinates to 0 for 2D layouts
# Extract edge positions
x_edges, y_edges, z_edges = [], [], []
for edge in graph.edges():
x_edges.extend([pos[edge[0]][0], pos[edge[1]][0], None])
y_edges.extend([pos[edge[0]][1], pos[edge[1]][1], None])
if layout_type == "3D Spring":
z_edges.extend([pos[edge[0]][2], pos[edge[1]][2], None])
else:
z_edges.extend([0, 0, None])
# Generate node colors based on user selection
if node_color_attribute == "Degree":
node_colors = [graph.degree(node) for node in graph.nodes()]
else: # Random
node_colors = [random.random() for _ in graph.nodes()]
        node_colors = np.array(node_colors, dtype=float)
        # Normalize to [0, 1]; guard against a zero range (all nodes share one value)
        color_range = node_colors.max() - node_colors.min()
        if color_range > 0:
            node_colors = (node_colors - node_colors.min()) / color_range
# Create the trace for edges
edge_trace = go.Scatter3d(
x=x_edges, y=y_edges, z=z_edges,
mode='lines',
line=dict(color='lightgray', width=edge_width),
hoverinfo='none'
)
# Create the trace for nodes
node_trace = go.Scatter3d(
x=x_nodes, y=y_nodes, z=z_nodes,
mode='markers+text' if show_labels else 'markers',
marker=dict(
size=node_size,
color=node_colors,
colorscale=color_scheme,
colorbar=dict(
title='Node Degree' if node_color_attribute == "Degree" else "Random Value",
thickness=10,
x=1.1,
tickvals=[0, 1],
ticktext=['Low', 'High']
),
line=dict(width=1)
),
text=[node for node in graph.nodes()],
textposition="top center",
textfont=dict(size=label_size, color='black'),
hoverinfo='text'
)
# Create the plot
fig = go.Figure(data=[edge_trace, node_trace])
# Update layout for better visualization
fig.update_layout(
title=f'{layout_type} Graph Visualization: {os.path.basename(graph_path)}',
showlegend=False,
scene=dict(
xaxis=dict(showbackground=False, showticklabels=False, title=''),
yaxis=dict(showbackground=False, showticklabels=False, title=''),
zaxis=dict(showbackground=False, showticklabels=False, title='')
),
margin=dict(l=0, r=0, b=0, t=40),
annotations=[
dict(
showarrow=False,
text=f"Interactive {layout_type} visualization of GraphML data",
xref="paper",
yref="paper",
x=0,
y=0
)
],
autosize=True
)
fig.update_layout(autosize=True)
fig.update_layout(height=600) # Set a fixed height
return fig, f"Graph visualization generated successfully. Using file: {graph_path}"
except Exception as e:
return go.Figure(), f"Error visualizing graph: {str(e)}"
def update_logs():
logs = []
while not log_queue.empty():
logs.append(log_queue.get())
return "\n".join(logs)
def fetch_models(base_url, api_key, service_type):
try:
if service_type.lower() == "ollama":
response = requests.get(f"{base_url}/tags", timeout=10)
else: # OpenAI Compatible
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json"
}
response = requests.get(f"{base_url}/models", headers=headers, timeout=10)
logging.info(f"Raw API response: {response.text}")
if response.status_code == 200:
data = response.json()
if service_type.lower() == "ollama":
models = [model.get('name', '') for model in data.get('models', data) if isinstance(model, dict)]
else: # OpenAI Compatible
models = [model.get('id', '') for model in data.get('data', []) if isinstance(model, dict)]
models = [model for model in models if model] # Remove empty strings
if not models:
logging.warning(f"No models found in {service_type} API response")
return ["No models available"]
logging.info(f"Successfully fetched {service_type} models: {models}")
return models
else:
logging.error(f"Error fetching {service_type} models. Status code: {response.status_code}, Response: {response.text}")
return ["Error fetching models"]
except requests.RequestException as e:
logging.error(f"Exception while fetching {service_type} models: {str(e)}")
return ["Error: Connection failed"]
except Exception as e:
logging.error(f"Unexpected error in fetch_models: {str(e)}")
return ["Error: Unexpected issue"]
def update_model_choices(base_url, api_key, service_type, settings_key):
models = fetch_models(base_url, api_key, service_type)
if not models:
logging.warning(f"No models fetched for {service_type}.")
    # Get the current model from settings: 'llm' stores the model at the top level,
    # while 'embeddings' nests it under an inner 'llm' block
    if settings_key == 'llm':
        current_model = settings.get('llm', {}).get('model')
    else:
        current_model = settings.get(settings_key, {}).get('llm', {}).get('model')
# If the current model is not in the list, add it
if current_model and current_model not in models:
models.append(current_model)
return gr.update(choices=models, value=current_model if current_model in models else (models[0] if models else None))
def update_llm_model_choices(base_url, api_key, service_type):
return update_model_choices(base_url, api_key, service_type, 'llm')
def update_embeddings_model_choices(base_url, api_key, service_type):
return update_model_choices(base_url, api_key, service_type, 'embeddings')
def update_llm_settings(llm_model, embeddings_model, context_window, system_message, temperature, max_tokens,
llm_api_base, llm_api_key,
embeddings_api_base, embeddings_api_key, embeddings_service_type):
try:
# Update settings.yaml
settings = load_settings()
        settings.setdefault('llm', {}).update({
"type": "openai", # Always set to "openai" since we removed the radio button
"model": llm_model,
"api_base": llm_api_base,
"api_key": "${GRAPHRAG_API_KEY}",
"temperature": temperature,
"max_tokens": max_tokens,
"provider": "openai_chat" # Always set to "openai_chat"
})
        settings.setdefault('embeddings', {}).setdefault('llm', {}).update({
"type": "openai_embedding", # Always use OpenAIEmbeddingsLLM
"model": embeddings_model,
"api_base": embeddings_api_base,
"api_key": "${GRAPHRAG_API_KEY}",
"provider": embeddings_service_type
})
with open("indexing/settings.yaml", 'w') as f:
yaml.dump(settings, f, default_flow_style=False)
# Update .env file
update_env_file("LLM_API_BASE", llm_api_base)
update_env_file("LLM_API_KEY", llm_api_key)
update_env_file("LLM_MODEL", llm_model)
update_env_file("EMBEDDINGS_API_BASE", embeddings_api_base)
update_env_file("EMBEDDINGS_API_KEY", embeddings_api_key)
update_env_file("EMBEDDINGS_MODEL", embeddings_model)
update_env_file("CONTEXT_WINDOW", str(context_window))
update_env_file("SYSTEM_MESSAGE", system_message)
update_env_file("TEMPERATURE", str(temperature))
update_env_file("MAX_TOKENS", str(max_tokens))
update_env_file("LLM_SERVICE_TYPE", "openai_chat")
update_env_file("EMBEDDINGS_SERVICE_TYPE", embeddings_service_type)
# Reload environment variables
load_dotenv(override=True)
return "LLM and embeddings settings updated successfully in both settings.yaml and .env files."
except Exception as e:
return f"Error updating LLM and embeddings settings: {str(e)}"
def update_env_file(key, value):
env_path = 'indexing/.env'
with open(env_path, 'r') as file:
lines = file.readlines()
updated = False
for i, line in enumerate(lines):
if line.startswith(f"{key}="):
lines[i] = f"{key}={value}\n"
updated = True
break
if not updated:
lines.append(f"{key}={value}\n")
with open(env_path, 'w') as file:
file.writelines(lines)
custom_css = """
html, body {
margin: 0;
padding: 0;
height: 100vh;
overflow: hidden;
}
.gradio-container {
margin: 0 !important;
padding: 0 !important;
width: 100vw !important;
max-width: 100vw !important;
height: 100vh !important;
max-height: 100vh !important;
overflow: auto;
display: flex;
flex-direction: column;
}
#main-container {
flex: 1;
display: flex;
overflow: hidden;
}
#left-column, #right-column {
height: 100%;
overflow-y: auto;
padding: 10px;
}
#left-column {
flex: 1;
}
#right-column {
flex: 2;
display: flex;
flex-direction: column;
}
#chat-container {
flex: 0 0 auto; /* Don't allow this to grow */
height: 100%;
display: flex;
flex-direction: column;
overflow: hidden;
border: 1px solid var(--color-accent);
border-radius: 8px;
padding: 10px;
overflow-y: auto;
}
#chatbot {
overflow-y: hidden;
height: 100%;
}
#chat-input-row {
margin-top: 10px;
}
#visualization-plot {
width: 100%;
aspect-ratio: 1 / 1;
max-height: 600px; /* Adjust this value as needed */
}
#vis-controls-row {
display: flex;
justify-content: space-between;
align-items: center;
margin-top: 10px;
}
#vis-controls-row > * {
flex: 1;
margin: 0 5px;
}
#vis-status {
margin-top: 10px;
}
/* Chat input styling */
#chat-input-row {
display: flex;
flex-direction: column;
}
#chat-input-row > div {
width: 100% !important;
}
#chat-input-row input[type="text"] {
width: 100% !important;
}
/* Adjust padding for all containers */
.gr-box, .gr-form, .gr-panel {
padding: 10px !important;
}
/* Ensure all textboxes and textareas have full height */
.gr-textbox, .gr-textarea {
height: auto !important;
min-height: 100px !important;
}
/* Ensure all dropdowns have full width */
.gr-dropdown {
width: 100% !important;
}
:root {
--color-background: #2C3639;
--color-foreground: #3F4E4F;
--color-accent: #A27B5C;
--color-text: #DCD7C9;
}
body, .gradio-container {
background-color: var(--color-background);
color: var(--color-text);
}
.gr-button {
background-color: var(--color-accent);
color: var(--color-text);
}
.gr-input, .gr-textarea, .gr-dropdown {
background-color: var(--color-foreground);
color: var(--color-text);
border: 1px solid var(--color-accent);
}
.gr-panel {
background-color: var(--color-foreground);
border: 1px solid var(--color-accent);
}
.gr-box {
border-radius: 8px;
margin-bottom: 10px;
background-color: var(--color-foreground);
}
.gr-padded {
padding: 10px;
}
.gr-form {
background-color: var(--color-foreground);
}
.gr-input-label, .gr-radio-label {
color: var(--color-text);
}
.gr-checkbox-label {
color: var(--color-text);
}
.gr-markdown {
color: var(--color-text);
}
.gr-accordion {
background-color: var(--color-foreground);
border: 1px solid var(--color-accent);
}
.gr-accordion-header {
background-color: var(--color-accent);
color: var(--color-text);
}
#visualization-container {
display: flex;
flex-direction: column;
border: 2px solid var(--color-accent);
border-radius: 8px;
margin-top: 20px;
padding: 10px;
background-color: var(--color-foreground);
height: calc(100vh - 300px); /* Adjust this value as needed */
}
#visualization-plot {
width: 100%;
height: 100%;
}
#vis-controls-row {
display: flex;
justify-content: space-between;
align-items: center;
margin-top: 10px;
}
#vis-controls-row > * {
flex: 1;
margin: 0 5px;
}
#vis-status {
margin-top: 10px;
}
#log-container {
background-color: var(--color-foreground);
border: 1px solid var(--color-accent);
border-radius: 8px;
padding: 10px;
margin-top: 20px;
max-height: none;
overflow-y: auto;
}
.setting-accordion .label-wrap {
cursor: pointer;
}
.setting-accordion .icon {
transition: transform 0.3s ease;
}
.setting-accordion[open] .icon {
transform: rotate(90deg);
}
.gr-form.gr-box {
border: none !important;
background: none !important;
}
.model-params {
border-top: 1px solid var(--color-accent);
margin-top: 10px;
padding-top: 10px;
}
"""
def list_output_files(root_dir):
output_dir = os.path.join(root_dir, "output")
files = []
for root, _, filenames in os.walk(output_dir):
for filename in filenames:
files.append(os.path.join(root, filename))
return files
def update_file_list():
files = list_input_files()
return gr.update(choices=[f["path"] for f in files])
def update_file_content(file_path):
if not file_path:
return ""
try:
with open(file_path, 'r', encoding='utf-8') as file:
content = file.read()
return content
except Exception as e:
logging.error(f"Error reading file: {str(e)}")
return f"Error reading file: {str(e)}"
def list_output_folders(root_dir):
output_dir = os.path.join(root_dir, "output")
folders = [f for f in os.listdir(output_dir) if os.path.isdir(os.path.join(output_dir, f))]
return sorted(folders, reverse=True)
def list_folder_contents(folder_path):
contents = []
for item in os.listdir(folder_path):
item_path = os.path.join(folder_path, item)
if os.path.isdir(item_path):
contents.append(f"[DIR] {item}")
else:
_, ext = os.path.splitext(item)
contents.append(f"[{ext[1:].upper()}] {item}")
return contents
def update_output_folder_list():
root_dir = "./"
folders = list_output_folders(root_dir)
return gr.update(choices=folders, value=folders[0] if folders else None)
def update_folder_content_list(folder_name):
root_dir = "./"
if not folder_name:
return gr.update(choices=[])
contents = list_folder_contents(os.path.join(root_dir, "output", folder_name, "artifacts"))
return gr.update(choices=contents)
def handle_content_selection(folder_name, selected_item):
root_dir = "./"
if isinstance(selected_item, list) and selected_item:
selected_item = selected_item[0] # Take the first item if it's a list
if isinstance(selected_item, str) and selected_item.startswith("[DIR]"):
dir_name = selected_item[6:] # Remove "[DIR] " prefix
sub_contents = list_folder_contents(os.path.join(root_dir, "output", folder_name, dir_name))
return gr.update(choices=sub_contents), "", ""
elif isinstance(selected_item, str):
file_name = selected_item.split("] ")[1] if "]" in selected_item else selected_item # Remove file type prefix if present
file_path = os.path.join(root_dir, "output", folder_name, "artifacts", file_name)
file_size = os.path.getsize(file_path)
file_type = os.path.splitext(file_name)[1]
file_info = f"File: {file_name}\nSize: {file_size} bytes\nType: {file_type}"
content = read_file_content(file_path)
return gr.update(), file_info, content
else:
return gr.update(), "", ""
def initialize_selected_folder(folder_name):
root_dir = "./"
if not folder_name:
return "Please select a folder first.", gr.update(choices=[])
folder_path = os.path.join(root_dir, "output", folder_name, "artifacts")
if not os.path.exists(folder_path):
return f"Artifacts folder not found in '{folder_name}'.", gr.update(choices=[])
contents = list_folder_contents(folder_path)
return f"Folder '{folder_name}/artifacts' initialized with {len(contents)} items.", gr.update(choices=contents)
settings = load_settings()
default_model = settings.get('llm', {}).get('model')
cli_args = gr.State({})
stop_indexing = threading.Event()
indexing_thread = None
def start_indexing(*args):
global indexing_thread, stop_indexing
stop_indexing = threading.Event() # Reset the stop_indexing event
indexing_thread = threading.Thread(target=run_indexing, args=args)
indexing_thread.start()
return gr.update(interactive=False), gr.update(interactive=True), gr.update(interactive=False)
def stop_indexing_process():
global indexing_thread
logging.info("Stop indexing requested")
stop_indexing.set()
if indexing_thread and indexing_thread.is_alive():
logging.info("Waiting for indexing thread to finish")
indexing_thread.join(timeout=10)
logging.info("Indexing thread finished" if not indexing_thread.is_alive() else "Indexing thread did not finish within timeout")
indexing_thread = None # Reset the thread
return gr.update(interactive=True), gr.update(interactive=False), gr.update(interactive=True)
def refresh_indexing():
global indexing_thread, stop_indexing
if indexing_thread and indexing_thread.is_alive():
logging.info("Cannot refresh: Indexing is still running")
return gr.update(interactive=False), gr.update(interactive=True), gr.update(interactive=False), "Cannot refresh: Indexing is still running"
else:
stop_indexing = threading.Event() # Reset the stop_indexing event
indexing_thread = None # Reset the thread
return gr.update(interactive=True), gr.update(interactive=False), gr.update(interactive=True), "Indexing process refreshed. You can start indexing again."
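# Generator that streams output from `python -m graphrag.index` back to the UI and
# honors the stop_indexing event.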
def run_indexing(root_dir, config_file, verbose, nocache, resume, reporter, emit_formats, custom_args):
cmd = ["python", "-m", "graphrag.index", "--root", "./indexing"]
# Add custom CLI arguments
if custom_args:
cmd.extend(custom_args.split())
logging.info(f"Executing command: {' '.join(cmd)}")
process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1, encoding='utf-8', universal_newlines=True)
output = []
progress_value = 0
iterations_completed = 0
while True:
if stop_indexing.is_set():
process.terminate()
process.wait(timeout=5)
if process.poll() is None:
process.kill()
return ("\n".join(output + ["Indexing stopped by user."]),
"Indexing stopped.",
100,
gr.update(interactive=True),
gr.update(interactive=False),
gr.update(interactive=True),
str(iterations_completed))
try:
line = process.stdout.readline()
if not line and process.poll() is not None:
break
if line:
line = line.strip()
output.append(line)
if "Processing file" in line:
progress_value += 1
iterations_completed += 1
elif "Indexing completed" in line:
progress_value = 100
elif "ERROR" in line:
line = f"🚨 ERROR: {line}"
yield ("\n".join(output),
line,
progress_value,
gr.update(interactive=False),
gr.update(interactive=True),
gr.update(interactive=False),
str(iterations_completed))
except Exception as e:
logging.error(f"Error during indexing: {str(e)}")
return ("\n".join(output + [f"Error: {str(e)}"]),
"Error occurred during indexing.",
100,
gr.update(interactive=True),
gr.update(interactive=False),
gr.update(interactive=True),
str(iterations_completed))
if process.returncode != 0 and not stop_indexing.is_set():
final_output = "\n".join(output + [f"Error: Process exited with return code {process.returncode}"])
final_progress = "Indexing failed. Check output for details."
else:
final_output = "\n".join(output)
final_progress = "Indexing completed successfully!"
    yield (final_output,
           final_progress,
           gr.update(interactive=True),
           gr.update(interactive=False),
           gr.update(interactive=True),
           str(iterations_completed))
global_vector_store_wrapper = None
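# Assemble the Gradio Blocks UI (data management, indexing, settings, chat and
# visualization) and wire up all event handlers.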
def create_gradio_interface():
global global_vector_store_wrapper
llm_models, embeddings_models, llm_service_type, embeddings_service_type, llm_api_base, embeddings_api_base, text_embedder = initialize_models()
settings = load_settings()
log_output = gr.TextArea(label="Logs", elem_id="log-output", interactive=False, visible=False)
with gr.Blocks(css=custom_css, theme=gr.themes.Base()) as demo:
gr.Markdown("# GraphRAG Local UI", elem_id="title")
with gr.Row(elem_id="main-container"):
with gr.Column(scale=1, elem_id="left-column"):
with gr.Tabs():
with gr.TabItem("Data Management"):
with gr.Accordion("File Upload (.txt)", open=True):
file_upload = gr.File(label="Upload .txt File", file_types=[".txt"])
upload_btn = gr.Button("Upload File", variant="primary")
upload_output = gr.Textbox(label="Upload Status", visible=False)
with gr.Accordion("File Management", open=True):
file_list = gr.Dropdown(label="Select File", choices=[], interactive=True)
refresh_btn = gr.Button("Refresh File List", variant="secondary")
file_content = gr.TextArea(label="File Content", lines=10)
with gr.Row():
delete_btn = gr.Button("Delete Selected File", variant="stop")
save_btn = gr.Button("Save Changes", variant="primary")
operation_status = gr.Textbox(label="Operation Status", visible=False)
with gr.TabItem("Indexing"):
root_dir = gr.Textbox(label="Root Directory", value="./")
config_file = gr.File(label="Config File (optional)")
with gr.Row():
verbose = gr.Checkbox(label="Verbose", value=True)
nocache = gr.Checkbox(label="No Cache", value=True)
with gr.Row():
resume = gr.Textbox(label="Resume Timestamp (optional)")
reporter = gr.Dropdown(label="Reporter", choices=["rich", "print", "none"], value=None)
with gr.Row():
emit_formats = gr.CheckboxGroup(label="Emit Formats", choices=["json", "csv", "parquet"], value=None)
with gr.Row():
run_index_button = gr.Button("Run Indexing")
stop_index_button = gr.Button("Stop Indexing", variant="stop")
refresh_index_button = gr.Button("Refresh Indexing", variant="secondary")
with gr.Accordion("Custom CLI Arguments", open=True):
custom_cli_args = gr.Textbox(
label="Custom CLI Arguments",
placeholder="--arg1 value1 --arg2 value2",
lines=3
)
cli_guide = gr.Markdown(
textwrap.dedent("""
### CLI Argument Key Guide:
- `--root <path>`: Set the root directory for the project
- `--config <path>`: Specify a custom configuration file
- `--verbose`: Enable verbose output
- `--nocache`: Disable caching
- `--resume <timestamp>`: Resume from a specific timestamp
- `--reporter <type>`: Set the reporter type (rich, print, none)
- `--emit <formats>`: Specify output formats (json, csv, parquet)
Example: `--verbose --nocache --emit json,csv`
""")
)
index_output = gr.Textbox(label="Indexing Output", lines=20, max_lines=30)
index_progress = gr.Textbox(label="Indexing Progress", lines=3)
iterations_completed = gr.Textbox(label="Iterations Completed", value="0")
refresh_status = gr.Textbox(label="Refresh Status", visible=True)
run_index_button.click(
fn=start_indexing,
inputs=[root_dir, config_file, verbose, nocache, resume, reporter, emit_formats, custom_cli_args],
outputs=[run_index_button, stop_index_button, refresh_index_button]
).then(
fn=run_indexing,
inputs=[root_dir, config_file, verbose, nocache, resume, reporter, emit_formats, custom_cli_args],
outputs=[index_output, index_progress, run_index_button, stop_index_button, refresh_index_button, iterations_completed]
)
stop_index_button.click(
fn=stop_indexing_process,
outputs=[run_index_button, stop_index_button, refresh_index_button]
)
refresh_index_button.click(
fn=refresh_indexing,
outputs=[run_index_button, stop_index_button, refresh_index_button, refresh_status]
)
with gr.TabItem("Indexing Outputs/Visuals"):
output_folder_list = gr.Dropdown(label="Select Output Folder (Select GraphML File to Visualize)", choices=list_output_folders("./indexing"), interactive=True)
refresh_folder_btn = gr.Button("Refresh Folder List", variant="secondary")
initialize_folder_btn = gr.Button("Initialize Selected Folder", variant="primary")
folder_content_list = gr.Dropdown(label="Select File or Directory", choices=[], interactive=True)
file_info = gr.Textbox(label="File Information", interactive=False)
output_content = gr.TextArea(label="File Content", lines=20, interactive=False)
initialization_status = gr.Textbox(label="Initialization Status")
with gr.TabItem("LLM Settings"):
llm_base_url = gr.Textbox(label="LLM API Base URL", value=os.getenv("LLM_API_BASE"))
llm_api_key = gr.Textbox(label="LLM API Key", value=os.getenv("LLM_API_KEY"), type="password")
llm_service_type = gr.Radio(
label="LLM Service Type",
choices=["openai", "ollama"],
value="openai",
visible=False # Hide this if you want to always use OpenAI
)
llm_model_dropdown = gr.Dropdown(
label="LLM Model",
choices=[], # Start with an empty list
value=settings['llm'].get('model'),
allow_custom_value=True
)
refresh_llm_models_btn = gr.Button("Refresh LLM Models", variant="secondary")
embeddings_base_url = gr.Textbox(label="Embeddings API Base URL", value=os.getenv("EMBEDDINGS_API_BASE"))
embeddings_api_key = gr.Textbox(label="Embeddings API Key", value=os.getenv("EMBEDDINGS_API_KEY"), type="password")
embeddings_service_type = gr.Radio(
label="Embeddings Service Type",
choices=["openai", "ollama"],
value=settings.get('embeddings', {}).get('llm', {}).get('type', 'openai'),
visible=False,
)
embeddings_model_dropdown = gr.Dropdown(
label="Embeddings Model",
choices=[],
value=settings.get('embeddings', {}).get('llm', {}).get('model'),
allow_custom_value=True
)
refresh_embeddings_models_btn = gr.Button("Refresh Embedding Models", variant="secondary")
system_message = gr.Textbox(
lines=5,
label="System Message",
value=os.getenv("SYSTEM_MESSAGE", "You are a helpful AI assistant.")
)
context_window = gr.Slider(
label="Context Window",
minimum=512,
maximum=32768,
step=512,
value=int(os.getenv("CONTEXT_WINDOW", 4096))
)
temperature = gr.Slider(
label="Temperature",
minimum=0.0,
maximum=2.0,
step=0.1,
                            value=float(settings['llm'].get('temperature', 0.5))
)
max_tokens = gr.Slider(
label="Max Tokens",
minimum=1,
maximum=8192,
step=1,
                            value=int(settings['llm'].get('max_tokens', 1024))
)
update_settings_btn = gr.Button("Update LLM Settings", variant="primary")
llm_settings_status = gr.Textbox(label="Status", interactive=False)
llm_base_url.change(
fn=update_model_choices,
inputs=[llm_base_url, llm_api_key, llm_service_type, gr.Textbox(value='llm', visible=False)],
outputs=llm_model_dropdown
)
# Update Embeddings model choices when service type or base URL changes
embeddings_service_type.change(
fn=update_embeddings_model_choices,
inputs=[embeddings_base_url, embeddings_api_key, embeddings_service_type],
outputs=embeddings_model_dropdown
)
embeddings_base_url.change(
fn=update_model_choices,
inputs=[embeddings_base_url, embeddings_api_key, embeddings_service_type, gr.Textbox(value='embeddings', visible=False)],
outputs=embeddings_model_dropdown
)
update_settings_btn.click(
fn=update_llm_settings,
inputs=[
llm_model_dropdown,
embeddings_model_dropdown,
context_window,
system_message,
temperature,
max_tokens,
llm_base_url,
llm_api_key,
embeddings_base_url,
embeddings_api_key,
embeddings_service_type
],
outputs=[llm_settings_status]
)
refresh_llm_models_btn.click(
fn=update_model_choices,
inputs=[llm_base_url, llm_api_key, llm_service_type, gr.Textbox(value='llm', visible=False)],
outputs=[llm_model_dropdown]
).then(
fn=update_logs,
outputs=[log_output]
)
refresh_embeddings_models_btn.click(
fn=update_model_choices,
inputs=[embeddings_base_url, embeddings_api_key, embeddings_service_type, gr.Textbox(value='embeddings', visible=False)],
outputs=[embeddings_model_dropdown]
).then(
fn=update_logs,
outputs=[log_output]
)
with gr.TabItem("YAML Settings"):
settings = load_settings()
with gr.Group():
for key, value in settings.items():
if key != 'llm':
create_setting_component(key, value)
with gr.Group(elem_id="log-container"):
gr.Markdown("### Logs")
log_output = gr.TextArea(label="Logs", elem_id="log-output", interactive=False)
with gr.Column(scale=2, elem_id="right-column"):
with gr.Group(elem_id="chat-container"):
chatbot = gr.Chatbot(label="Chat History", elem_id="chatbot")
with gr.Row(elem_id="chat-input-row"):
with gr.Column(scale=1):
query_input = gr.Textbox(
label="Input",
placeholder="Enter your query here...",
elem_id="query-input"
)
query_btn = gr.Button("Send Query", variant="primary")
with gr.Accordion("Query Parameters", open=True):
query_type = gr.Radio(
["global", "local", "direct"],
label="Query Type",
value="global",
info="Global: community-based search, Local: entity-based search, Direct: LLM chat"
)
preset_dropdown = gr.Dropdown(
label="Preset Query Options",
choices=[
"Default Global Search",
"Default Local Search",
"Detailed Global Analysis",
"Detailed Local Analysis",
"Quick Global Summary",
"Quick Local Summary",
"Global Bullet Points",
"Local Bullet Points",
"Comprehensive Global Report",
"Comprehensive Local Report",
"High-Level Global Overview",
"High-Level Local Overview",
"Focused Global Insight",
"Focused Local Insight",
"Custom Query"
],
value="Default Global Search",
info="Select a preset or choose 'Custom Query' for manual configuration"
)
selected_folder = gr.Dropdown(
label="Select Index Folder to Chat With",
choices=list_output_folders("./indexing"),
value=None,
interactive=True
)
refresh_folder_btn = gr.Button("Refresh Folders", variant="secondary")
clear_chat_btn = gr.Button("Clear Chat", variant="secondary")
with gr.Group(visible=False) as custom_options:
community_level = gr.Slider(
label="Community Level",
minimum=1,
maximum=10,
value=2,
step=1,
info="Higher values use reports on smaller communities"
)
response_type = gr.Dropdown(
label="Response Type",
choices=[
"Multiple Paragraphs",
"Single Paragraph",
"Single Sentence",
"List of 3-7 Points",
"Single Page",
"Multi-Page Report"
],
value="Multiple Paragraphs",
info="Specify the desired format of the response"
)
custom_cli_args = gr.Textbox(
label="Custom CLI Arguments",
placeholder="--arg1 value1 --arg2 value2",
info="Additional CLI arguments for advanced users"
)
def update_custom_options(preset):
if preset == "Custom Query":
return gr.update(visible=True)
else:
return gr.update(visible=False)
preset_dropdown.change(fn=update_custom_options, inputs=[preset_dropdown], outputs=[custom_options])
with gr.Group(elem_id="visualization-container"):
vis_output = gr.Plot(label="Graph Visualization", elem_id="visualization-plot")
with gr.Row(elem_id="vis-controls-row"):
vis_btn = gr.Button("Visualize Graph", variant="secondary")
# Add new controls for customization
with gr.Accordion("Visualization Settings", open=False):
layout_type = gr.Dropdown(["3D Spring", "2D Spring", "Circular"], label="Layout Type", value="3D Spring")
node_size = gr.Slider(1, 20, 7, label="Node Size", step=1)
edge_width = gr.Slider(0.1, 5, 0.5, label="Edge Width", step=0.1)
node_color_attribute = gr.Dropdown(["Degree", "Random"], label="Node Color Attribute", value="Degree")
color_scheme = gr.Dropdown(["Viridis", "Plasma", "Inferno", "Magma", "Cividis"], label="Color Scheme", value="Viridis")
show_labels = gr.Checkbox(label="Show Node Labels", value=True)
label_size = gr.Slider(5, 20, 10, label="Label Size", step=1)
# Event handlers
upload_btn.click(fn=upload_file, inputs=[file_upload], outputs=[upload_output, file_list, log_output])
refresh_btn.click(fn=update_file_list, outputs=[file_list]).then(
fn=update_logs,
outputs=[log_output]
)
file_list.change(fn=update_file_content, inputs=[file_list], outputs=[file_content]).then(
fn=update_logs,
outputs=[log_output]
)
delete_btn.click(fn=delete_file, inputs=[file_list], outputs=[operation_status, file_list, log_output])
save_btn.click(fn=save_file_content, inputs=[file_list, file_content], outputs=[operation_status, log_output])
refresh_folder_btn.click(
fn=lambda: gr.update(choices=list_output_folders("./indexing")),
outputs=[selected_folder]
)
clear_chat_btn.click(
fn=lambda: ([], ""),
outputs=[chatbot, query_input]
)
refresh_folder_btn.click(
fn=update_output_folder_list,
outputs=[output_folder_list]
).then(
fn=update_logs,
outputs=[log_output]
)
output_folder_list.change(
fn=update_folder_content_list,
inputs=[output_folder_list],
outputs=[folder_content_list]
).then(
fn=update_logs,
outputs=[log_output]
)
folder_content_list.change(
fn=handle_content_selection,
inputs=[output_folder_list, folder_content_list],
outputs=[folder_content_list, file_info, output_content]
).then(
fn=update_logs,
outputs=[log_output]
)
initialize_folder_btn.click(
fn=initialize_selected_folder,
inputs=[output_folder_list],
outputs=[initialization_status, folder_content_list]
).then(
fn=update_logs,
outputs=[log_output]
)
vis_btn.click(
fn=update_visualization,
inputs=[
output_folder_list,
folder_content_list,
layout_type,
node_size,
edge_width,
node_color_attribute,
color_scheme,
show_labels,
label_size
],
outputs=[vis_output, gr.Textbox(label="Visualization Status")]
)
query_btn.click(
fn=send_message,
inputs=[
query_type,
query_input,
chatbot,
system_message,
temperature,
max_tokens,
preset_dropdown,
community_level,
response_type,
custom_cli_args,
selected_folder
],
outputs=[chatbot, query_input, log_output]
)
query_input.submit(
fn=send_message,
inputs=[
query_type,
query_input,
chatbot,
system_message,
temperature,
max_tokens,
preset_dropdown,
community_level,
response_type,
custom_cli_args,
selected_folder
],
outputs=[chatbot, query_input, log_output]
)
refresh_llm_models_btn.click(
fn=update_model_choices,
inputs=[llm_base_url, llm_api_key, llm_service_type, gr.Textbox(value='llm', visible=False)],
outputs=[llm_model_dropdown]
)
# Update Embeddings model choices
refresh_embeddings_models_btn.click(
fn=update_model_choices,
inputs=[embeddings_base_url, embeddings_api_key, embeddings_service_type, gr.Textbox(value='embeddings', visible=False)],
outputs=[embeddings_model_dropdown]
)
# Add this JavaScript to enable Shift+Enter functionality
demo.load(js="""
function addShiftEnterListener() {
const queryInput = document.getElementById('query-input');
if (queryInput) {
queryInput.addEventListener('keydown', function(event) {
if (event.key === 'Enter' && event.shiftKey) {
event.preventDefault();
const submitButton = queryInput.closest('.gradio-container').querySelector('button.primary');
if (submitButton) {
submitButton.click();
}
}
});
}
}
document.addEventListener('DOMContentLoaded', addShiftEnterListener);
""")
return demo.queue()
async def main():
api_port = 8088
gradio_port = 7860
print(f"Starting API server on port {api_port}")
start_api_server(api_port)
# Wait for the API server to start in a separate thread
threading.Thread(target=wait_for_api_server, args=(api_port,)).start()
# Create the Gradio app
demo = create_gradio_interface()
print(f"Starting Gradio app on port {gradio_port}")
# Launch the Gradio app
demo.launch(server_port=gradio_port, share=True)
demo = create_gradio_interface()
app = demo.app
if __name__ == "__main__":
initialize_data()
demo.launch(server_port=7860, share=True)