File size: 1,640 Bytes
cd607b2
 
eac37df
cd607b2
eac37df
 
 
cd607b2
 
7b856a8
69deff6
 
7b856a8
 
4e3dc76
7b856a8
 
 
 
 
 
 
 
4e3dc76
 
 
 
 
7b856a8
4e3dc76
 
 
 
7b856a8
4e3dc76
 
 
7b856a8
 
69deff6
 
7b856a8
4e3dc76
 
 
 
 
 
 
 
7b856a8
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
# + tags=["hide_inp"]
# Markdown description rendered by the demo UI (passed to `show` as `description=` below).
desc = """
### Book QA

Chain that does question answering with Hugging Face embeddings. [[Code](https://github.com/srush/MiniChain/blob/main/examples/gatsby.py)]

(Adapted from the [LlamaIndex example](https://github.com/jerryjliu/gpt_index/blob/main/examples/gatsby/TestGatsby.ipynb).)
"""
# -

# $

import datasets
import numpy as np
from minichain import prompt, show, HuggingFaceEmbed, OpenAI

# Load data with embeddings (computed beforehand).
# The dataset must NOT be bound to the name `gatsby`: the chain function
# `gatsby` defined further down rebinds that name, and `get_neighbors`
# looks the dataset up at call time.

gatsby_data = datasets.load_from_disk("gatsby")
gatsby_data.add_faiss_index("embeddings")

# Fast KNN retrieval prompt

@prompt(HuggingFaceEmbed("sentence-transformers/all-mpnet-base-v2"))
def get_neighbors(model, inp, k=1):
    """Return the ``k`` book passages nearest to ``inp`` in embedding space.

    ``model`` is supplied by the ``@prompt`` decorator (callers pass only
    ``inp``); it embeds the query text. The embedding is searched against
    the FAISS index built on the ``"embeddings"`` column, and the
    ``"passages"`` column of the nearest rows is returned.
    """
    embedding = model(inp)
    # FAISS searches with float32 vectors; cast explicitly instead of
    # relying on the FAISS Python wrapper to convert the dtype.
    query = np.asarray(embedding, dtype=np.float32)
    res = gatsby_data.get_nearest_examples("embeddings", query, k)
    return res.examples["passages"]

@prompt(OpenAI(), template_file="gatsby.pmpt.tpl")
def ask(model, query, neighbors):
    """Ask the OpenAI-backed prompt to answer ``query`` from ``neighbors``.

    ``model`` is supplied by the ``@prompt`` decorator, which is configured
    with the ``gatsby.pmpt.tpl`` template. The question and the retrieved
    passages are passed to it under the ``question`` / ``docs`` keys.
    """
    template_vars = {"question": query, "docs": neighbors}
    return model(template_vars)

def gatsby(query):
    """Run the book-QA chain: retrieve passages for ``query``, then ``ask``."""
    return ask(query, get_neighbors(query))


# $


# The demo's code panel shows the part of this file between the two marker
# comments, with surrounding comment characters and whitespace stripped.
# Read the file inside a context manager so the handle is closed promptly.
with open("gatsby.py", "r", encoding="utf-8") as _source_file:
    _chain_source = _source_file.read()

# NOTE(review): `ask` is backed by OpenAI(), but only HF_KEY is requested
# here; confirm whether an OpenAI key should also be listed in `keys`.
gradio = show(gatsby,
              subprompts=[get_neighbors, ask],
              examples=["What did Gatsby do before he met Daisy?",
                        "What did the narrator do after getting back to Chicago?"],
              keys={"HF_KEY"},
              description=desc,
              code=_chain_source.split("$")[1].strip().strip("#").strip()
              )
if __name__ == "__main__":
    gradio.launch()



# + tags=["hide_inp"]
# QAPrompt().show({"question": "Who was Gatsby?", "docs": ["doc1", "doc2", "doc3"]}, "")
# # -

# show_log("gatsby.log")