# NOTE: The three lines below were a Hugging Face Spaces status banner
# ("Spaces: / Runtime error / Runtime error") captured when this file was
# scraped from the Space page; kept as a comment so the module parses.
import gradio as gr | |
from rpc import load_index, generate | |
def setup():
    """Download the memory-mapped vector index (~10GB), install NGT if needed, and load it.

    Configuration is read from the environment:
      * ``DATASET_URL`` (required) — URL of a zip archive containing the index.
      * ``INDEX_TYPE`` (optional) — ``"ngt"`` selects the NGT backend, which
        requires a native install step and keeps its data under ``index/``.

    The archive is streamed and unzipped on the fly into ``/dev/shm`` so the
    large index never has to exist twice (zip + extracted) on disk.

    Raises:
        ValueError: if ``DATASET_URL`` is not set.
        requests.HTTPError: if the download request fails.
        subprocess.CalledProcessError: if the NGT install script fails.
    """
    import os
    import subprocess

    import requests
    from stream_unzip import stream_unzip

    dataset_url = os.environ.get("DATASET_URL")
    index_type = os.environ.get("INDEX_TYPE")
    if not dataset_url:
        raise ValueError("DATASET_URL must be set in the environment")

    extract_dir = "/dev/shm/rpc-vecdb"
    os.makedirs(extract_dir, exist_ok=True)

    # Stream the zip straight from the network into extracted files. The
    # `with` block guarantees the HTTP connection is released even if the
    # extraction loop raises part-way through (the original leaked it).
    with requests.get(dataset_url, stream=True) as response:
        response.raise_for_status()
        print("Starting streaming extraction to /dev/shm...")
        for filename, filesize, file_iter in stream_unzip(response.iter_content(chunk_size=8192)):
            if isinstance(filename, bytes):
                filename = filename.decode('utf-8')
            file_path = os.path.join(extract_dir, filename)
            # Archive entries may live in subdirectories; create them as needed.
            os.makedirs(os.path.dirname(file_path), exist_ok=True)
            with open(file_path, 'wb') as f_out:
                for chunk in file_iter:
                    f_out.write(chunk)
            print(f"Extracted: {filename} -> {file_path}")

    # Log the top-level files that were extracted, for debugging.
    files = [f for f in os.listdir(extract_dir)
             if os.path.isfile(os.path.join(extract_dir, f))]
    for f in files:
        print(f)
    print("Index extracted")

    if index_type == "ngt":
        print("Installing NGT...")
        subprocess.check_call(["bash", "install_ngt.sh"])
        print("NGT installed")

    print("Loading index...")
    # The NGT backend keeps its data in an "index/" subdirectory of the archive.
    if index_type == "ngt":
        index_dir = os.path.join(extract_dir, "index")
    else:
        index_dir = extract_dir
    load_index(index_path=index_dir, idx_type=index_type)
    print("Index loaded")
def respond(
    message,
    history: list[tuple[str, str]],
    user_name,
    ai_name,
    use_rpc,
    max_tokens,
    temperature,
):
    """Stream a chat reply token by token for the Gradio UI.

    Rebuilds a plain-text dialogue transcript from ``history`` plus the new
    ``message``, then yields the cumulative response after each generated
    token so the interface renders it incrementally.

    NOTE(review): ``temperature`` is accepted (it is wired to a UI slider)
    but never forwarded to ``generate`` — confirm whether the backend reads
    it through some other channel.
    """
    # One "User: ...\nAI: ...\n" line pair per past exchange.
    turns = [
        f"{user_name}: {user_msg.strip()}\n{ai_name}: {ai_msg.strip()}\n"
        for user_msg, ai_msg in history
    ]
    prompt = "<s>" + "".join(turns) + f"{user_name}: {message.strip()}\n{ai_name}:"

    response = ""
    for token in generate(prompt, use_rpc=use_rpc, max_tokens=max_tokens):
        response += token
        yield response
    print(history, message, response)
# Extra controls shown under the chat box: persona names for prompt
# construction, plus generation knobs for comparing the plain model against
# the RPC-enhanced one.
_additional_inputs = [
    gr.Textbox(value="Jake", label="User name"),
    gr.Textbox(value="Sarah", label="AI name"),
    gr.Checkbox(
        label="Use RPC",
        info="Compare Normal vs. RPC-Enhanced Model",
        value=True,
    ),
    gr.Slider(minimum=1, maximum=320, value=128, step=1, label="Max new tokens"),
    gr.Slider(
        minimum=0.1,
        maximum=3.0,
        value=0.2,
        step=0.1,
        label="Temperature (only used without RPC)",
    ),
]

demo = gr.ChatInterface(
    respond,
    additional_inputs=_additional_inputs,
    description=(
        "Remember that you are talking with a 5M parameter model trained "
        "on allenai/soda, not ChatGPT"
    ),
)
def _main() -> None:
    """Entry point: fetch/load the vector index, then serve the Gradio app."""
    setup()
    demo.launch()


if __name__ == "__main__":
    _main()