# Scraped file-viewer header (not part of the program): file size 3,043 bytes; revision hashes 449cbf5 / ab2d07e; line-number gutter 1-86.
import argparse
import os
from dotenv import load_dotenv
from langchain.globals import set_debug
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from lib.repository import download_github_repo
from lib.loader import load_files
from lib.chain import create_retriever, create_qa_chain
from lib.utils import read_prompt, load_LLM, select_model
from lib.models import MODELS_MAP
import time
import gradio as gr
def slow_echo(message, history, delay=0.05):
    """Stream ``message`` back as progressively longer prefixes.

    Args:
        message: The text to echo.
        history: Unused by this function.
        delay: Seconds to sleep before each yielded prefix. Defaults to
            0.05, the value that was previously hard-coded.

    Yields:
        ``message[:1]``, ``message[:2]``, ... up to the full message.
        Nothing is yielded for an empty message.
    """
    # Walk the prefix end positions directly instead of deriving them
    # from an index with ``i + 1``.
    for end in range(1, len(message) + 1):
        time.sleep(delay)
        yield message[:end]
# set_debug(True)
def build():
    """Build the Gradio UI and launch it.

    The page has a repository-URL textbox with a submit button, a question
    textbox with its own submit button, and an output textbox for the
    answer. Submitting a URL runs ``main`` on it; submitting a question
    invokes the module-global ``qa_chain`` that ``main`` creates.
    """
    with gr.Blocks() as demo:
        repo_url = gr.Textbox(label="Repo URL", placeholder="Enter the repository URL here...")
        submit_btn = gr.Button("Submit Repo URL")
        user_input = gr.Textbox(label="User Input", placeholder="Enter your question here...")
        chat_output = gr.Textbox(label="Chat Output", placeholder="The answer will appear here...")
        # TODO: add a status textbox

        def update_repo_url(new_url):
            """Run the full setup for ``new_url`` and return main()'s result."""
            return main(new_url)

        def generate_answer(user_input):
            """Answer ``user_input`` with the current QA chain, if one exists."""
            # ``qa_chain`` is only defined as a module global once main() has
            # completed; looking it up via globals() avoids a NameError when
            # a question is submitted before any repository was loaded.
            chain = globals().get("qa_chain")
            if chain is None:
                return "No repository has been loaded yet. Please submit a repo URL first."
            answer = chain.invoke(user_input)
            print(f"Answer: {answer}")
            # NOTE(review): assumes the chain result is a mapping with an
            # 'output' key - confirm against create_qa_chain.
            return answer['output']

        submit_btn.click(update_repo_url, inputs=repo_url, outputs=repo_url)
        user_input_submit_btn = gr.Button("Submit Question")
        user_input_submit_btn.click(generate_answer, inputs=user_input, outputs=chat_output)

    # Start serving once the layout context has been closed.
    demo.launch()
def main(repo_url):
    """Download a repository, index it, and create the global QA chain.

    Steps, in order: select a model, download the repository into
    ``data/<repo_name>`` next to this script, read the two prompt templates,
    load the repository files into document chunks, then build the LLM,
    the retriever and the QA chain. The chain is stored in the module-global
    ``qa_chain``.

    Args:
        repo_url: URL of the repository to download. The last path segment
            (without a trailing ``.git``) is used as the local folder name.

    Returns:
        ``repo_url``, unchanged.

    Raises:
        ValueError: If no repository name can be derived from ``repo_url``
            (for example, an empty string).
        KeyError: If the selected model name is not a key of ``MODELS_MAP``.
    """
    global qa_chain

    # Select the model (the original comment says this prompts the user).
    model_name = select_model()
    # The looked-up value is not used below; the lookup is kept so that an
    # unknown model name still fails here, before any download starts.
    model_info = MODELS_MAP[model_name]

    # Derive the folder name from the last URL segment. Strip trailing
    # slashes first, and remove ".git" only as a suffix so that names such
    # as "user.github.io" are not mangled.
    repo_name = repo_url.rstrip("/").split("/")[-1]
    if repo_name.endswith(".git"):
        repo_name = repo_name[:-len(".git")]
    if not repo_name:
        raise ValueError(f"Could not determine a repository name from URL: {repo_url!r}")

    # Compute the path to the data folder relative to the script's directory
    base_dir = os.path.dirname(os.path.abspath(__file__))
    repo_dir = os.path.join(base_dir, "data", repo_name)
    # NOTE(review): the same db directory is used for every repository -
    # confirm create_retriever does not mix data from earlier runs.
    db_dir = os.path.join(base_dir, "data", "db")
    prompt_templates_dir = os.path.join(base_dir, "prompt_templates")

    # Download the GitHub repository
    print(f"Downloading repository from {repo_url}...")
    download_github_repo(repo_url, repo_dir)

    # Load prompt templates
    prompts_text = {
        "initial_prompt": read_prompt(os.path.join(prompt_templates_dir, 'initial_prompt.txt')),
        "evaluation_prompt": read_prompt(os.path.join(prompt_templates_dir, 'evaluation_prompt.txt')),
    }

    # Load documents from the repository
    print(f"Loading documents from {repo_dir}...")
    document_chunks = load_files(repository_path=repo_dir)
    print(f"Created chunks length is: {len(document_chunks)}")

    # Create model, retriever
    print(f"Creating retrieval QA chain using {model_name}...")
    llm = load_LLM(model_name)
    retriever = create_retriever(model_name, db_dir, document_chunks)
    qa_chain = create_qa_chain(llm, retriever, prompts_text)

    print("Ready to chat!")
    return repo_url
# Build and launch the UI only when this file is executed as a script.
if __name__ == "__main__":
    build()