|
import argparse |
|
import os |
|
from dotenv import load_dotenv |
|
|
|
from langchain.globals import set_debug |
|
from langchain_core.runnables import RunnablePassthrough |
|
from langchain_core.output_parsers import StrOutputParser |
|
|
|
from lib.repository import download_github_repo |
|
from lib.loader import load_files |
|
from lib.chain import create_retriever, create_qa_chain |
|
from lib.utils import read_prompt, load_LLM, select_model |
|
from lib.models import MODELS_MAP |
|
|
|
import time |
|
import gradio as gr |
|
|
|
def slow_echo(message, history, delay=0.05):
    """Stream *message* back one character at a time, like a typing effect.

    The signature matches a Gradio chat callback ``(message, history)``;
    *history* is accepted but not used.

    Args:
        message: Text to echo.
        history: Chat history supplied by the caller (ignored).
        delay: Seconds to sleep before yielding each prefix. Defaults to
            0.05, the previously hard-coded value; pass 0 to disable.

    Yields:
        Successively longer prefixes of *message*: ``message[:1]``,
        ``message[:2]``, ..., ``message``. Nothing is yielded for "".
    """
    for prefix_len in range(1, len(message) + 1):
        time.sleep(delay)
        yield message[:prefix_len]
|
|
|
|
|
|
|
def build():
    """Build the repository-QA Gradio UI and launch it (blocks in ``launch``).

    Workflow exposed to the user:

    1. Enter a repository URL and press "Submit Repo URL". This calls
       ``main`` to download and index the repository, which also creates the
       module-level ``qa_chain``. The (cleaned) URL is written back into the
       URL textbox.
    2. Enter a question and press "Submit Question". The question is passed
       to ``qa_chain.invoke`` and the ``'output'`` entry of its result is
       shown in the "Chat Output" textbox.
    """
    with gr.Blocks() as demo:
        repo_url = gr.Textbox(label="Repo URL", placeholder="Enter the repository URL here...")
        submit_btn = gr.Button("Submit Repo URL")

        user_input = gr.Textbox(label="User Input", placeholder="Enter your question here...")
        chat_output = gr.Textbox(label="Chat Output", placeholder="The answer will appear here...")

        def update_repo_url(new_url):
            """Index the repository at *new_url*; return the URL for the textbox."""
            url = (new_url or "").strip()
            if not url:
                # An empty URL cannot name a repository; reject it before
                # main() starts downloading anything.
                raise gr.Error("Please enter a repository URL.")
            return main(url)

        def generate_answer(question):
            """Answer *question* using the QA chain created by ``main``."""
            # qa_chain is a module global that only exists once main() has
            # run; asking first used to crash with a NameError.
            chain = globals().get("qa_chain")
            if chain is None:
                raise gr.Error("Please submit a repository URL before asking a question.")
            answer = chain.invoke(question)
            print(f"Answer: {answer}")
            return answer['output']

        submit_btn.click(update_repo_url, inputs=repo_url, outputs=repo_url)

        user_input_submit_btn = gr.Button("Submit Question")
        user_input_submit_btn.click(generate_answer, inputs=user_input, outputs=chat_output)

    demo.launch()
|
|
|
def _repo_name_from_url(repo_url):
    """Return the last path segment of *repo_url* without a trailing ``.git``.

    Trailing slashes are ignored, and ``.git`` is removed only as a suffix
    (not wherever it occurs in the name). For example, both
    ``https://github.com/user/repo.git`` and ``https://github.com/user/repo/``
    give ``repo``. May return "" when the URL has no usable segment.
    """
    name = repo_url.strip().rstrip("/").split("/")[-1]
    if name.endswith(".git"):
        name = name[: -len(".git")]
    return name


def main(repo_url):
    """Download a repository, index it, and build the global ``qa_chain``.

    Steps: pick a model via ``select_model``, download the repository into
    ``<this file's dir>/data/<repo name>``, read the two prompt templates from
    ``<this file's dir>/prompt_templates``, chunk the repository files with
    ``load_files``, and build a retriever (using ``<this file's dir>/data/db``)
    plus a QA chain from them.

    Args:
        repo_url: Repository URL. Surrounding whitespace is stripped.

    Returns:
        The stripped ``repo_url`` (echoed back into the UI textbox by the
        caller in this file).

    Raises:
        ValueError: If no repository name can be derived from *repo_url*.
            This is checked before any download, so an empty name never
            points the download at the shared ``data`` directory.
        KeyError: If the selected model name is not in ``MODELS_MAP``.

    Side effects:
        Sets the module-level global ``qa_chain``.
    """
    global qa_chain

    repo_url = repo_url.strip()
    repo_name = _repo_name_from_url(repo_url)
    if not repo_name:
        raise ValueError(f"Cannot derive a repository name from URL: {repo_url!r}")

    model_name = select_model()
    # Not used further in this function; the lookup does reject unknown
    # model names with a KeyError before any expensive work starts.
    model_info = MODELS_MAP[model_name]

    base_dir = os.path.dirname(os.path.abspath(__file__))
    repo_dir = os.path.join(base_dir, "data", repo_name)
    db_dir = os.path.join(base_dir, "data", "db")
    prompt_templates_dir = os.path.join(base_dir, "prompt_templates")

    print(f"Downloading repository from {repo_url}...")
    download_github_repo(repo_url, repo_dir)

    prompts_text = {
        "initial_prompt": read_prompt(os.path.join(prompt_templates_dir, 'initial_prompt.txt')),
        "evaluation_prompt": read_prompt(os.path.join(prompt_templates_dir, 'evaluation_prompt.txt')),
    }

    print(f"Loading documents from {repo_dir}...")
    document_chunks = load_files(repository_path=repo_dir)
    print(f"Created chunks length is: {len(document_chunks)}")

    print(f"Creating retrieval QA chain using {model_name}...")
    llm = load_LLM(model_name)
    retriever = create_retriever(model_name, db_dir, document_chunks)
    qa_chain = create_qa_chain(llm, retriever, prompts_text)
    print("Ready to chat!")

    return repo_url
|
|
|
if __name__ == "__main__":
    # load_dotenv is imported at the top of this file but was never called,
    # so settings in a .env file were never loaded into os.environ. Load them
    # before the UI (and therefore main()) runs.
    # NOTE(review): presumably the model backends read API keys from the
    # environment - confirm against lib.utils.load_LLM / lib.chain.
    load_dotenv()
    build()