Spaces:
Runtime error
Runtime error
finishing the main feature
Browse files- .env +7 -0
- __pycache__/openai.cpython-310.pyc +0 -0
- __pycache__/openai_manager.cpython-310.pyc +0 -0
- __pycache__/qdrant.cpython-310.pyc +0 -0
- app.py +16 -3
- openai_manager.py +104 -0
- qdrant.py +101 -0
- requirements.txt +4 -0
.env
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
COLLECTION_NAME=sales_role_play
|
2 |
+
COLLECTION_SIZE=1536
|
3 |
+
QDRANT_PORT=6333
|
4 |
+
QDRANT_HOST=e47c5f9f-b98c-4eda-b28d-062e0cdbebda.eu-central-1-0.aws.cloud.qdrant.io
|
5 |
+
QDRANT_API_KEY=VIO8ss9siM8p48NpLYIhjHVKG4S5sufwFvD8g3AiSnXz_hXwQLi7tQ
|
6 |
+
OPENAI_API_KEY=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
|
7 |
+
OPENAI_ORGANIZATION=org-vzVB2Aj8Ipkxkee7l6w0e2GH
|
__pycache__/openai.cpython-310.pyc
ADDED
Binary file (2.97 kB). View file
|
|
__pycache__/openai_manager.cpython-310.pyc
ADDED
Binary file (3.11 kB). View file
|
|
__pycache__/qdrant.cpython-310.pyc
ADDED
Binary file (3.42 kB). View file
|
|
app.py
CHANGED
@@ -1,9 +1,22 @@
|
|
1 |
import gradio as gr
|
|
|
|
|
2 |
|
3 |
|
4 |
-
def
|
5 |
-
|
|
|
|
|
|
|
|
|
6 |
|
|
|
|
|
7 |
|
8 |
-
|
|
|
|
|
|
|
|
|
|
|
9 |
iface.launch()
|
|
|
1 |
import gradio as gr
|
2 |
+
from qdrant import qdrant_manager
|
3 |
+
from openai_manager import openai_manager
|
4 |
|
5 |
|
6 |
+
def generate(keywords):
    """Generate a sales script from comma-separated *keywords*.

    Embeds the keywords, retrieves the most similar stored example scripts
    from Qdrant, and uses them as few-shot context for the OpenAI completion.

    Args:
        keywords: Comma-separated keyword string from the Gradio textbox.

    Returns:
        The generated script text, or None when the OpenAI call fails.
    """
    try:
        keywords_list = [kw.strip() for kw in keywords.split(",")]
    except AttributeError:
        # Non-string input (e.g. None): warn the user, fall back to empty.
        keywords_list = []
        gr.Warning("Please use ',' to separate Keywords")

    joined_keywords = " ".join(keywords_list)
    print("keywords", joined_keywords)

    # BUG FIX: the embedding was previously computed from a hard-coded debug
    # string ("hello my name is zinou") instead of the user's keywords, and a
    # bare string was passed where get_embeddings expects a list of prompts
    # (a bare string would be embedded character by character).
    embedding = openai_manager.get_embeddings([joined_keywords])

    points = qdrant_manager.search_point(query_vector=embedding[0])

    return openai_manager.shots(points, joined_keywords)
|
19 |
+
|
20 |
+
|
21 |
+
# Single-textbox UI: keywords in, generated script out.
iface = gr.Interface(fn=generate, inputs="text", outputs="text")
iface.launch()
|
openai_manager.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from dotenv import load_dotenv
|
3 |
+
import openai
|
4 |
+
|
5 |
+
|
6 |
+
class OpenAiManager:
    """Thin wrapper around the OpenAI SDK for completions and embeddings.

    Reads OPENAI_API_KEY and OPENAI_ORGANIZATION from the environment
    (loaded from .env) at construction time and configures the global
    `openai` module with them.
    """

    def __init__(self):
        load_dotenv()
        openai.api_key = os.getenv("OPENAI_API_KEY")
        openai.organization = os.getenv("OPENAI_ORGANIZATION")

    def get_completion(
        self,
        prompt,
        model="text-davinci-003",
        max_tokens=128,
        temperature=0,
    ):
        """Return the completion text for *prompt*, or None on API failure."""
        response = None
        try:
            response = openai.Completion.create(
                prompt=prompt,
                max_tokens=max_tokens,
                model=model,
                temperature=temperature,
            )["choices"][0]["text"]

        except Exception as err:
            print(f"Sorry, There was a problem \n\n {err}")

        return response

    def get_chat_completion(self, prompt, model="gpt-3.5-turbo"):
        """Send *prompt* as a single system message; return the stripped
        reply text, or None on API failure."""
        response = None
        try:
            response = (
                openai.ChatCompletion.create(
                    model=model,
                    messages=[
                        {
                            "role": "system",
                            "content": prompt,
                        }
                    ],
                )
                .choices[0]
                .message.content.strip()
            )

        except Exception as err:
            print(f"Sorry, There was a problem \n\n {err}")

        return response

    def get_embedding(self, prompt, model="text-embedding-ada-002"):
        """Return the embedding vector for one prompt, or None on failure.

        Newlines are flattened to spaces before embedding.
        """
        prompt = prompt.replace("\n", " ")

        embedding = None
        try:
            embedding = openai.Embedding.create(input=[prompt], model=model)["data"][0][
                "embedding"
            ]

        except Exception as err:
            print(f"Sorry, There was a problem {err}")

        return embedding

    def get_embeddings(self, prompts, model="text-embedding-ada-002"):
        """Return one embedding vector per prompt; [] when the API call fails.

        BUG FIX: the previous version iterated over None after a failed API
        call, raising TypeError and masking the original error; it now
        returns an empty list instead.
        """
        prompts = [prompt.replace("\n", " ") for prompt in prompts]

        try:
            embeddings = openai.Embedding.create(input=prompts, model=model)["data"]
        except Exception as err:
            print(f"Sorry, There was a problem {err}")
            return []

        return [embedding["embedding"] for embedding in embeddings]

    def shots(self, examples, keywords):
        """Build a few-shot prompt from retrieved *examples* (Qdrant points
        with 'keywords'/'example' payload keys) and complete for *keywords*.
        """
        prompt = []

        for example in examples:
            prompt.append(
                f"""
            keywords: {example.payload["keywords"]}
            script: {example.payload["example"]}
            """
            )

        prompt.append(
            f"""
            keywords: {keywords}
            script:
            """
        )

        prompt = "\n\n".join(prompt)

        return self.get_chat_completion(prompt=prompt)
|
102 |
+
|
103 |
+
|
104 |
+
# Module-level singleton shared by the app.
openai_manager = OpenAiManager()
qdrant.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from qdrant_client import QdrantClient
|
3 |
+
from qdrant_client.http import models
|
4 |
+
|
5 |
+
from dotenv import load_dotenv
|
6 |
+
from uuid import uuid4
|
7 |
+
|
8 |
+
load_dotenv()

# Qdrant connection/collection settings from .env (all read as strings).
COLLECTION_NAME = os.getenv("COLLECTION_NAME")
COLLECTION_SIZE = os.getenv("COLLECTION_SIZE")  # vector dimensionality, e.g. "1536"
QDRANT_PORT = os.getenv("QDRANT_PORT")
QDRANT_HOST = os.getenv("QDRANT_HOST")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
|
15 |
+
|
16 |
+
|
17 |
+
class QdrantManager:
    """
    A class for managing collections in the Qdrant database.

    Args:
        collection_name (str): The name of the collection to manage.
        collection_size (int): The dimensionality of the collection's vectors.
        port (int): The port number for the Qdrant API.
        host (str): The hostname or IP address for the Qdrant server.
        api_key (str): The API key for authenticating with the Qdrant server.
        recreate_collection (bool): Whether to recreate the collection if it already exists.

    Attributes:
        client (qdrant_client.QdrantClient): The Qdrant client object for interacting with the API.
    """

    def __init__(
        self,
        collection_name=COLLECTION_NAME,
        collection_size: int = COLLECTION_SIZE,
        port: int = QDRANT_PORT,
        host=QDRANT_HOST,
        api_key=QDRANT_API_KEY,
        recreate_collection: bool = False,
    ):
        self.collection_name = collection_name
        self.collection_size = collection_size
        self.host = host
        self.port = port
        self.api_key = api_key

        self.client = QdrantClient(host=host, port=port, api_key=api_key)
        self.setup_collection(collection_size, recreate_collection)

    def setup_collection(self, collection_size: int, recreate_collection: bool):
        """Optionally recreate the collection, then verify its vector size.

        Raises:
            ValueError: If the existing collection's vector size differs
                from *collection_size*.
        """
        if recreate_collection:
            self.recreate_collection()

        # BUG FIX: the size check previously sat inside the same try/except
        # as get_collection_info(), so the deliberately raised ValueError was
        # immediately caught and only printed, never reaching the caller.
        # Only the lookup (which fails when the collection does not exist
        # yet) remains best-effort.
        try:
            collection_info = self.get_collection_info()
        except Exception as e:
            print(e)
            return

        current_collection_size = collection_info["vector_size"]
        if current_collection_size != int(collection_size):
            raise ValueError(
                f"""
                Existing collection {self.collection_name} has different collection size
                To use the new collection configuration, you need to recreate the collection as it already exists with a different configuration.
                use recreate_collection = True.
                """
            )

    def recreate_collection(self):
        """Drop and recreate the collection using cosine distance."""
        self.client.recreate_collection(
            collection_name=self.collection_name,
            vectors_config=models.VectorParams(
                # int() guards against the str value read from the environment.
                size=int(self.collection_size),
                distance=models.Distance.COSINE,
            ),
        )

    def get_collection_info(self):
        """Return point/vector counts and the vector size of the collection."""
        collection_info = self.client.get_collection(
            collection_name=self.collection_name
        )

        return {
            "points_count": int(collection_info.points_count),
            "vectors_count": int(collection_info.vectors_count),
            "indexed_vectors_count": int(collection_info.indexed_vectors_count),
            "vector_size": int(collection_info.config.params.vectors.size),
        }

    def search_point(self, query_vector, limit=5):
        """Return up to *limit* points most similar to *query_vector*."""
        response = self.client.search(
            collection_name=self.collection_name,
            query_vector=query_vector,
            limit=limit,
        )

        return response
|
99 |
+
|
100 |
+
|
101 |
+
# Module-level singleton shared by the app.
qdrant_manager = QdrantManager()
requirements.txt
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
gradio
|
2 |
+
qdrant-client
|
3 |
+
openai
|
4 |
+
python-dotenv
|