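# Gradio chat demo comparing normal generation with RPC-enhanced generation,
# backed by a memory-mapped NGT vector index (downloaded and loaded in setup()).
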
import gradio as gr
from rpc import load_index, generate

def setup():
    """ Downloads the memory mapped vector index (~10GB), installs NGT and loads the index"""
    import os
    import requests
    import subprocess
    from stream_unzip import stream_unzip
    
    DATASET_URL = os.environ.get("DATASET_URL")
    if not DATASET_URL:
        raise ValueError("DATASET_URL must be set in the environment")
    extract_dir = "/dev/shm/rpc-vecdb"
    os.makedirs(extract_dir, exist_ok=True)
    response = requests.get(DATASET_URL, stream=True)
    response.raise_for_status()
    
    print("Starting streaming extraction to /dev/shm...")
    for filename, filesize, file_iter in stream_unzip(response.iter_content(chunk_size=8192)):
        if isinstance(filename, bytes):
            filename = filename.decode('utf-8')
        if filename.endswith('/'):
            # Directory entries carry no data, but their chunk iterator must still
            # be drained before stream_unzip can move on to the next member.
            for _ in file_iter:
                pass
            continue
        file_path = os.path.join(extract_dir, filename)
        os.makedirs(os.path.dirname(file_path), exist_ok=True)
        with open(file_path, 'wb') as f_out:
            for chunk in file_iter:
                f_out.write(chunk)
        print(f"Extracted: {filename} -> {file_path}")
    
    files = [f for f in os.listdir(extract_dir) if os.path.isfile(os.path.join(extract_dir, f))]
    for f in files:
        print(f)
    print("Index extracted")

    print("Installing NGT...")
    subprocess.check_call(["bash", "install_ngt.sh"])
    print("NGT installed")

    print("Loading index...")
    load_index(extract_dir + "/index")
    print("Index loaded")
    


def respond(
    message,
    history: list[tuple[str, str]],
    user_name,
    ai_name,
    use_rpc,
    max_tokens,
    temperature,
):
    prompt = "<s>"
    print(history)
    for m in history:
        prompt += f"{user_name}: {m[0].strip()}\n{ai_name}: {m[1].strip()}\n"
    prompt += f"{user_name}: {message.strip()}\n{ai_name}:"

    response = ""
    for tok in generate(prompt, use_rpc=use_rpc, max_tokens=max_tokens):
        response += tok
        yield response



demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="Peter", label="User name"),
        gr.Textbox(value="Sarah", label="Assistant name"),
        gr.Checkbox(
            label="Use RPC",
            info="Compare Normal vs. RPC-Enhanced Model",
            value=True
        ),
        gr.Slider(minimum=1, maximum=320, value=128, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=3.0, value=0.2, step=0.1, label="Temperature (only used without RPC)"),
    ],
)
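# The five additional_inputs above are passed to respond() positionally after
# (message, history): user_name, ai_name, use_rpc, max_tokens, temperature.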


if __name__ == "__main__":
    setup()
    demo.launch()