import gradio as gr
from rpc import load_index, generate


def setup():
    """Download the memory-mapped vector index (~10 GB), install NGT if needed, and load the index."""
    import os
    import requests
    import subprocess
    from stream_unzip import stream_unzip

    DATASET_URL = os.environ.get("DATASET_URL")
    INDEX_TYPE = os.environ.get("INDEX_TYPE")
    if not DATASET_URL:
        raise ValueError("DATASET_URL must be set in the environment")

    # Extract into /dev/shm (tmpfs) so the memory-mapped index is backed by RAM.
    extract_dir = "/dev/shm/rpc-vecdb"
    os.makedirs(extract_dir, exist_ok=True)

    response = requests.get(DATASET_URL, stream=True)
    response.raise_for_status()

    print("Starting streaming extraction to /dev/shm...")
    # Unzip the archive while it is still downloading, so the full zip is never
    # buffered on disk.
    for filename, filesize, file_iter in stream_unzip(response.iter_content(chunk_size=8192)):
        if isinstance(filename, bytes):
            filename = filename.decode('utf-8')
        file_path = os.path.join(extract_dir, filename)
        os.makedirs(os.path.dirname(file_path), exist_ok=True)
        with open(file_path, 'wb') as f_out:
            for chunk in file_iter:
                f_out.write(chunk)
        print(f"Extracted: {filename} -> {file_path}")

    files = os.listdir(extract_dir)
    files = [f for f in files if os.path.isfile(os.path.join(extract_dir, f))]
    for f in files:
        print(f)
    print("Index extracted")

    if INDEX_TYPE == "ngt":
        print("Installing NGT...")
        subprocess.check_call(["bash", "install_ngt.sh"])
        print("NGT installed")

    print("Loading index...")
    # The NGT index lives in an "index" subdirectory of the archive; other index
    # types are loaded directly from the extraction root.
    if INDEX_TYPE == "ngt":
        index_dir = extract_dir + "/index"
    else:
        index_dir = extract_dir
    load_index(index_path=index_dir, idx_type=INDEX_TYPE)
    print("Index loaded")


def respond(
    message,
    history: list[tuple[str, str]],
    user_name,
    ai_name,
    use_rpc,
    max_tokens,
    temperature,
):
    # Flatten the chat history into a plain "Name: utterance" dialogue prompt.
    prompt = "<s>"
    for m in history:
        prompt += f"{user_name}: {m[0].strip()}\n{ai_name}: {m[1].strip()}\n"
    prompt += f"{user_name}: {message.strip()}\n{ai_name}:"

    # Stream partial responses back to the chat UI as tokens are generated.
    response = ""
    for tok in generate(prompt, use_rpc=use_rpc, max_tokens=max_tokens):
        response += tok
        yield response
    print(history, message, response)
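
# For reference, with the default names below a short exchange produces a prompt of
# roughly this shape (illustrative only, reconstructed from respond() above):
#   <s>Jake: hi
#   Sarah: hello!
#   Jake: how are you?
#   Sarah: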

demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="Jake", label="User name"),
        gr.Textbox(value="Sarah", label="AI name"),
        gr.Checkbox(
            label="Use RPC",
            info="Compare Normal vs. RPC-Enhanced Model",
            value=True,
        ),
        gr.Slider(minimum=1, maximum=320, value=128, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=3.0, value=0.2, step=0.1, label="Temperature (only used without RPC)"),
    ],
    description="Remember that you are talking with a 5M parameter model trained on allenai/soda, not ChatGPT",
)


if __name__ == "__main__":
    setup()
    demo.launch()
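
# Example local launch (a sketch, not part of the original Space): the URL is a
# placeholder, the file name "app.py" is assumed, and INDEX_TYPE is assumed to be
# either "ngt" (NGT index in an "index/" subfolder of the archive) or another value
# understood by rpc.load_index.
#
#   DATASET_URL="https://example.com/rpc-vecdb.zip" INDEX_TYPE="ngt" python app.py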