# RPC-Chat / app.py
# pedrocas15's picture
# Update app.py
# fc8cc49 verified
import gradio as gr
from rpc import load_index, generate
def setup():
    """Download the memory mapped vector index (~10GB), install NGT and load the index.

    Reads ``DATASET_URL`` (required) and ``INDEX_TYPE`` from the environment,
    streams the zip archive found at that URL into ``/dev/shm/rpc-vecdb``
    member by member, runs ``install_ngt.sh`` when ``INDEX_TYPE`` is
    ``"ngt"``, and finally hands the extracted directory to ``load_index``.

    Raises:
        ValueError: if ``DATASET_URL`` is unset or empty, or if an archive
            member would be written outside the extraction directory.
        requests.HTTPError: if the download answers with an error status.
        subprocess.CalledProcessError: if ``install_ngt.sh`` exits non-zero.
    """
    import os
    import requests
    import subprocess
    from stream_unzip import stream_unzip

    DATASET_URL = os.environ.get("DATASET_URL")
    INDEX_TYPE = os.environ.get("INDEX_TYPE")

    if not DATASET_URL:
        raise ValueError("DATASET_URL must be set in the environment")

    extract_dir = "/dev/shm/rpc-vecdb"
    os.makedirs(extract_dir, exist_ok=True)
    # Resolved once so every archive member can be checked against it.
    extract_root = os.path.realpath(extract_dir)

    # The context manager releases the streamed connection even when
    # extraction fails midway; the timeout (connect, read) keeps a stalled
    # server from hanging startup forever.
    with requests.get(DATASET_URL, stream=True, timeout=(10, 300)) as response:
        response.raise_for_status()

        print("Starting streaming extraction to /dev/shm...")
        # 1 MiB chunks: far fewer Python-level iterations than 8 KiB for a
        # multi-gigabyte archive.
        chunks = response.iter_content(chunk_size=1 << 20)
        for filename, filesize, file_iter in stream_unzip(chunks):
            if isinstance(filename, bytes):
                filename = filename.decode('utf-8')

            file_path = os.path.join(extract_dir, filename)

            # Refuse members such as "../x" or "/abs/x" that would land
            # outside the extraction directory (zip-slip).
            resolved = os.path.realpath(file_path)
            if os.path.commonpath([extract_root, resolved]) != extract_root:
                raise ValueError(f"Unsafe path in archive: {filename}")

            if filename.endswith("/"):
                # Directory entry: create it, but still drain its (empty)
                # data iterator, because stream_unzip requires every member
                # to be consumed before it yields the next one.
                os.makedirs(file_path, exist_ok=True)
                for _ in file_iter:
                    pass
                continue

            os.makedirs(os.path.dirname(file_path), exist_ok=True)
            with open(file_path, 'wb') as f_out:
                for chunk in file_iter:
                    f_out.write(chunk)

            print(f"Extracted: {filename} -> {file_path}")

    # List the regular files sitting directly in the extraction directory.
    files = [
        f for f in os.listdir(extract_dir)
        if os.path.isfile(os.path.join(extract_dir, f))
    ]
    for f in files:
        print(f)
    print("Index extracted")

    if INDEX_TYPE == "ngt":
        print("Installing NGT...")
        subprocess.check_call(["bash", "install_ngt.sh"])
        print("NGT installed")

    print("Loading index...")
    # For NGT the index is loaded from the "index" subdirectory of the
    # extraction directory; every other type loads from the directory itself.
    if INDEX_TYPE == "ngt":
        index_dir = os.path.join(extract_dir, "index")
    else:
        index_dir = extract_dir

    load_index(index_path=index_dir, idx_type=INDEX_TYPE)
    print("Index loaded")
def respond(
    message,
    history: list[tuple[str, str]],
    user_name,
    ai_name,
    use_rpc,
    max_tokens,
    temperature,
):
    """Stream the model's reply to ``message`` as a growing string.

    The prompt is a plain-text transcript: ``"<s>"``, then one
    ``"<user_name>: ...\\n<ai_name>: ...\\n"`` pair per entry of ``history``,
    then the new user message followed by an open ``"<ai_name>:"`` turn for
    the model to complete.

    Yields:
        The reply accumulated so far, once per token produced by ``generate``.
    """
    # NOTE(review): `temperature` is accepted but never forwarded to
    # generate() -- confirm whether that is intended.
    transcript = ["<s>"]
    transcript.extend(
        f"{user_name}: {turn[0].strip()}\n{ai_name}: {turn[1].strip()}\n"
        for turn in history
    )
    transcript.append(f"{user_name}: {message.strip()}\n{ai_name}:")
    prompt = "".join(transcript)

    reply = ""
    token_stream = generate(prompt, use_rpc=use_rpc, max_tokens=max_tokens)
    for tok in token_stream:
        reply += tok
        # Yield the whole reply so far rather than only the newest token.
        yield reply

    # Log the finished exchange once the token stream is exhausted.
    print(history, message, reply)
# Extra controls shown alongside the chat box. Their values are handed to
# respond() after (message, history), so the order here has to line up with
# respond's remaining parameters:
# user_name, ai_name, use_rpc, max_tokens, temperature.
_chat_controls = [
    gr.Textbox(label="User name", value="Jake"),
    gr.Textbox(label="AI name", value="Sarah"),
    gr.Checkbox(
        value=True,
        label="Use RPC",
        info="Compare Normal vs. RPC-Enhanced Model",
    ),
    gr.Slider(label="Max new tokens", minimum=1, maximum=320, step=1, value=128),
    gr.Slider(
        label="Temperature (only used without RPC)",
        minimum=0.1,
        maximum=3.0,
        step=0.1,
        value=0.2,
    ),
]

# Chat UI whose replies are streamed by respond().
demo = gr.ChatInterface(
    fn=respond,
    additional_inputs=_chat_controls,
    description="Remember that you are talking with a 5M parameter model trained on allenai/soda, not ChatGPT",
)


if __name__ == "__main__":
    # Fetch and load the index first, then start serving the UI.
    setup()
    demo.launch()