File size: 2,926 Bytes
754bd10
f8ab6ea
754bd10
f8ab6ea
 
 
 
 
 
 
 
fb2c227
 
f8ab6ea
 
fb2c227
f8ab6ea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fb2c227
 
 
 
f8ab6ea
 
fb2c227
 
 
 
 
f8ab6ea
 
754bd10
 
 
 
 
f8ab6ea
 
 
754bd10
 
 
f8ab6ea
 
f579840
71f9478
abcef1f
 
7d65953
5da7f16
754bd10
90b65f8
754bd10
 
f8ab6ea
754bd10
 
 
22974fb
 
f8ab6ea
 
 
 
754bd10
f8ab6ea
 
754bd10
fc8cc49
754bd10
 
 
 
bce4284
754bd10
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import gradio as gr
from rpc import load_index, generate

def setup():
    """Download the memory-mapped vector index (~10GB), install NGT if needed, and load the index.

    Reads configuration from the environment:
      DATASET_URL -- required; URL of a zip archive containing the index files.
      INDEX_TYPE  -- optional; when "ngt", runs install_ngt.sh and loads from the
                     "index" subdirectory instead of the extraction root.

    The archive is streamed straight into /dev/shm so the (large) zip is never
    held on disk in full.

    Raises:
        ValueError: if DATASET_URL is not set.
        requests.HTTPError: if the download request fails.
    """
    import os
    import requests
    import subprocess
    from stream_unzip import stream_unzip

    DATASET_URL = os.environ.get("DATASET_URL")
    INDEX_TYPE  = os.environ.get("INDEX_TYPE")

    if not DATASET_URL:
        raise ValueError("DATASET_URL must be set in the environment")

    extract_dir = "/dev/shm/rpc-vecdb"
    os.makedirs(extract_dir, exist_ok=True)
    # Timeout guards against a server that accepts the connection but never
    # sends data; streaming keeps memory use bounded for the ~10GB payload.
    response = requests.get(DATASET_URL, stream=True, timeout=(10, 300))
    response.raise_for_status()

    print("Starting streaming extraction to /dev/shm...")
    for filename, filesize, file_iter in stream_unzip(response.iter_content(chunk_size=8192)):
        if isinstance(filename, bytes):
            filename = filename.decode('utf-8')
        file_path = os.path.join(extract_dir, filename)
        # Zip-slip guard: refuse entry names (e.g. "../x") that would escape
        # the extraction directory — the archive comes from a remote URL.
        if not os.path.realpath(file_path).startswith(os.path.realpath(extract_dir) + os.sep):
            raise ValueError(f"Unsafe path in archive: {filename}")
        if filename.endswith('/'):
            # Directory entry: nothing to write, but the chunk iterator must
            # still be drained so stream_unzip can advance to the next entry.
            os.makedirs(file_path, exist_ok=True)
            for _ in file_iter:
                pass
            continue
        os.makedirs(os.path.dirname(file_path), exist_ok=True)
        with open(file_path, 'wb') as f_out:
            for chunk in file_iter:
                f_out.write(chunk)
        print(f"Extracted: {filename} -> {file_path}")

    files = os.listdir(extract_dir)
    files = [f for f in files if os.path.isfile(os.path.join(extract_dir, f))]
    for f in files: print(f)
    print("Index extracted")

    if INDEX_TYPE == "ngt":
        print("Installing NGT...")
        subprocess.check_call(["bash", "install_ngt.sh"])
        print("NGT installed")

    print("Loading index...")
    if INDEX_TYPE == "ngt":
        index_dir = extract_dir + "/index"
    else:
        index_dir = extract_dir
    load_index(index_path=index_dir, idx_type=INDEX_TYPE)
    print("Index loaded")
    


def respond(
    message,
    history: list[tuple[str, str]],
    user_name,
    ai_name,
    use_rpc,
    max_tokens,
    temperature,
):
    """Stream a chat reply token by token for the Gradio ChatInterface.

    Builds a flat "<s>User: ...\\nAI: ..." transcript from the prior turns plus
    the new message, then yields the growing response as tokens arrive from
    `generate`.

    NOTE(review): `temperature` is accepted (the UI exposes a slider) but is
    not forwarded to `generate` here — confirm whether `generate` should
    receive it.
    """
    past_turns = [
        f"{user_name}: {user_msg.strip()}\n{ai_name}: {ai_msg.strip()}\n"
        for user_msg, ai_msg in history
    ]
    prompt = "<s>" + "".join(past_turns)
    prompt += f"{user_name}: {message.strip()}\n{ai_name}:"

    partial = ""
    for token in generate(prompt, use_rpc=use_rpc, max_tokens=max_tokens):
        partial += token
        yield partial
    # Debug trace of the completed exchange.
    print(history, message, partial)



# Gradio chat UI: wires `respond` to the web interface. The extra widgets are
# passed to `respond` in order as `user_name`, `ai_name`, `use_rpc`,
# `max_tokens`, `temperature` (matching its signature after message/history).
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="Jake", label="User name"),
        gr.Textbox(value="Sarah", label="AI name"),
        gr.Checkbox(
            label="Use RPC",
            info="Compare Normal vs. RPC-Enhanced Model",
            value=True
        ),
        gr.Slider(minimum=1, maximum=320, value=128, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=3.0, value=0.2, step=0.1, label="Temperature (only used without RPC)"),
    ],
    description="Remember that you are talking with a 5M parameter model trained on allenai/soda, not ChatGPT"
)


if __name__ == "__main__":
    # Download/extract/load the index before serving; setup() blocks until
    # the index is ready, then the Gradio app starts.
    setup()
    demo.launch()