# gitllm/main.py
import os
from dotenv import load_dotenv
from langchain.globals import set_debug
from lib.repository import download_github_repo
from lib.loader import load_files
from lib.chain import create_retriever, create_qa_chain
from lib.utils import read_prompt, load_LLM, select_model
from lib.models import MODELS_MAP
import time
import gradio as gr
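# Load environment variables (e.g. model provider API keys) from a local .env
# file, if one is present. Assumed intent from the dotenv import above.
load_dotenv()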
def slow_echo(message, history):
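    """Stream `message` back one character at a time (typewriter effect).

    Helper in the gr.ChatInterface streaming style; not wired into the
    Blocks UI below.
    """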
for i in range(len(message)):
time.sleep(0.05)
yield message[: i + 1]
# set_debug(True)
def build():
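    """Build and launch the Gradio front end.

    The UI takes a GitHub repository URL, indexes it via main(), and then
    answers questions about the code through the retrieval QA chain.
    """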
with gr.Blocks() as demo:
repo_url = gr.Textbox(label="Repo URL", placeholder="Enter the repository URL here...")
submit_btn = gr.Button("Submit Repo URL")
user_input = gr.Textbox(label="User Input", placeholder="Enter your question here...")
chat_output = gr.Textbox(label="Chat Output", placeholder="The answer will appear here...")
# add a status textbox
def update_repo_url(new_url):
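            """Download and index the repository at new_url, then echo the URL back."""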
updated_url = main(new_url)
return updated_url
        def generate_answer(user_input):
            # The QA chain only exists once a repository URL has been submitted.
            if "qa_chain" not in globals():
                return "Please submit a repository URL first."
            answer = qa_chain.invoke(user_input)
            print(f"Answer: {answer}")
            return answer['output']
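        # Wire the buttons: submitting a URL indexes the repo; submitting a
        # question runs it through the retrieval QA chain.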
submit_btn.click(update_repo_url, inputs=repo_url, outputs=repo_url)
user_input_submit_btn = gr.Button("Submit Question")
user_input_submit_btn.click(generate_answer, inputs=user_input, outputs=chat_output)
demo.launch()
def main(repo_url):
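    """Download, chunk, and index a GitHub repository, then build the QA chain.

    Returns the repo URL so the Gradio callback can echo it back to the UI.
    """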
# Prompt user to select the model
model_name = select_model()
model_info = MODELS_MAP[model_name]
repo_name = repo_url.split("/")[-1].replace(".git", "")
# Compute the path to the data folder relative to the script's directory
base_dir = os.path.dirname(os.path.abspath(__file__))
repo_dir = os.path.join(base_dir, "data", repo_name)
db_dir = os.path.join(base_dir, "data", "db")
prompt_templates_dir = os.path.join(base_dir, "prompt_templates")
# Download the GitHub repository
print(f"Downloading repository from {repo_url}...")
download_github_repo(repo_url, repo_dir)
# Load prompt templates
prompts_text = {
"initial_prompt": read_prompt(os.path.join(prompt_templates_dir, 'initial_prompt.txt')),
"evaluation_prompt": read_prompt(os.path.join(prompt_templates_dir, 'evaluation_prompt.txt')),
}
# Load documents from the repository
print(f"Loading documents from {repo_dir}...")
document_chunks = load_files(repository_path=repo_dir)
print(f"Created chunks length is: {len(document_chunks)}")
# Create model, retriever
print(f"Creating retrieval QA chain using {model_name}...")
llm = load_LLM(model_name)
retriever = create_retriever(model_name, db_dir, document_chunks)
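    # Expose the chain as a module-level global so generate_answer() in build()
    # can reach it once the repository has been indexed.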
global qa_chain
qa_chain = create_qa_chain(llm, retriever, prompts_text)
print(f"Ready to chat!")
return repo_url
if __name__ == "__main__":
build()
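# Usage: run `python main.py`, open the printed local Gradio URL, submit a
# GitHub repository URL (the model is chosen at the console prompt), then ask
# questions about the codebase in the question box.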