import gradio as gr

from rpc import load_index, generate


def setup():
    """Download the memory-mapped vector index (~10GB) into /dev/shm,
    optionally install NGT, and load the index.

    Reads configuration from the environment:
      - DATASET_URL: required; URL of a zip archive containing the index.
      - INDEX_TYPE: optional; when "ngt", installs NGT via install_ngt.sh
        and loads the index from the archive's "index" subdirectory.

    Raises:
        ValueError: if DATASET_URL is not set.
        requests.HTTPError: if the download request fails.
        subprocess.CalledProcessError: if the NGT install script fails.
    """
    import os
    import requests
    import subprocess
    from stream_unzip import stream_unzip

    DATASET_URL = os.environ.get("DATASET_URL")
    INDEX_TYPE = os.environ.get("INDEX_TYPE")
    if not DATASET_URL:
        raise ValueError("DATASET_URL must be set in the environment")

    # /dev/shm is RAM-backed, so the memory-mapped index is served from memory.
    extract_dir = "/dev/shm/rpc-vecdb"
    os.makedirs(extract_dir, exist_ok=True)

    print("Starting streaming extraction to /dev/shm...")
    # Stream-decompress the zip straight from the HTTP response so the large
    # archive never has to be held on disk in compressed form. The `with`
    # block ensures the connection is released even if extraction fails
    # (the original leaked the streamed response on error).
    # NOTE(review): archive member names are used as-is; a hostile archive
    # containing "../" paths could escape extract_dir (zip-slip). DATASET_URL
    # is operator-controlled here, so this is accepted — revisit if that changes.
    with requests.get(DATASET_URL, stream=True) as response:
        response.raise_for_status()
        for filename, _filesize, file_iter in stream_unzip(
            response.iter_content(chunk_size=8192)
        ):
            if isinstance(filename, bytes):
                filename = filename.decode('utf-8')
            file_path = os.path.join(extract_dir, filename)
            os.makedirs(os.path.dirname(file_path), exist_ok=True)
            with open(file_path, 'wb') as f_out:
                for chunk in file_iter:
                    f_out.write(chunk)
            # Fixed: previously printed a literal placeholder instead of the
            # archive member name.
            print(f"Extracted: {filename} -> {file_path}")

    # List the top-level files that landed in the extraction directory.
    for entry in os.listdir(extract_dir):
        if os.path.isfile(os.path.join(extract_dir, entry)):
            print(entry)
    print("Index extracted")

    if INDEX_TYPE == "ngt":
        print("Installing NGT...")
        subprocess.check_call(["bash", "install_ngt.sh"])
        print("NGT installed")

    print("Loading index...")
    # NGT archives keep the index in an "index" subdirectory; other index
    # types store it at the extraction root.
    index_dir = extract_dir + "/index" if INDEX_TYPE == "ngt" else extract_dir
    load_index(index_path=index_dir, idx_type=INDEX_TYPE)
    print("Index loaded")


def respond(
    message,
    history: list[tuple[str, str]],
    user_name,
    ai_name,
    use_rpc,
    max_tokens,
    temperature,
):
    """Stream a chat completion for gr.ChatInterface.

    Builds a plain-text dialogue prompt from the (user_msg, ai_msg) history
    pairs plus the new message, then yields the accumulated response as
    tokens arrive from generate().

    NOTE(review): `temperature` is accepted (it is wired to a UI slider)
    but never forwarded to generate() — confirm whether generate() should
    receive it when use_rpc is False.
    """
    turns = [
        f"{user_msg.strip()}\n".join([f"{user_name}: ", f"{ai_name}: {ai_msg.strip()}\n"])
        for user_msg, ai_msg in history
    ]
    # "".join avoids quadratic string concatenation as history grows.
    prompt = "".join(
        f"{user_name}: {user_msg.strip()}\n{ai_name}: {ai_msg.strip()}\n"
        for user_msg, ai_msg in history
    )
    prompt += f"{user_name}: {message.strip()}\n{ai_name}:"

    response = ""
    for tok in generate(prompt, use_rpc=use_rpc, max_tokens=max_tokens):
        response += tok
        yield response
    # Debug trace of the completed exchange.
    print(history, message, response)


demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="Jake", label="User name"),
        gr.Textbox(value="Sarah", label="AI name"),
        gr.Checkbox(
            label="Use RPC",
            info="Compare Normal vs. RPC-Enhanced Model",
            value=True,
        ),
        gr.Slider(minimum=1, maximum=320, value=128, step=1, label="Max new tokens"),
        gr.Slider(
            minimum=0.1,
            maximum=3.0,
            value=0.2,
            step=0.1,
            label="Temperature (only used without RPC)",
        ),
    ],
    description="Remember that you are talking with a 5M parameter model trained on allenai/soda, not ChatGPT",
)

if __name__ == "__main__":
    setup()
    demo.launch()