zinoubm committed on
Commit
5ddfe7e
1 Parent(s): c7fa0aa

finishing the main feature

Browse files
.env ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # SECURITY(review): real Qdrant and OpenAI credentials were committed in this
+ # file. Rotate/revoke those keys immediately and keep .env out of version
+ # control (add it to .gitignore). Placeholders below.
+ COLLECTION_NAME=sales_role_play
+ COLLECTION_SIZE=1536
+ QDRANT_PORT=6333
+ QDRANT_HOST=<your-qdrant-cluster-host>
+ QDRANT_API_KEY=<rotated-qdrant-api-key>
+ OPENAI_API_KEY=<rotated-openai-api-key>
+ OPENAI_ORGANIZATION=<your-openai-org-id>
__pycache__/openai.cpython-310.pyc ADDED
Binary file (2.97 kB). View file
 
__pycache__/openai_manager.cpython-310.pyc ADDED
Binary file (3.11 kB). View file
 
__pycache__/qdrant.cpython-310.pyc ADDED
Binary file (3.42 kB). View file
 
app.py CHANGED
@@ -1,9 +1,22 @@
1
  import gradio as gr
 
 
2
 
3
 
4
- def greet(name):
5
- return "Hello " + name + "!!"
 
 
 
 
6
 
 
 
7
 
8
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
 
 
 
 
 
9
  iface.launch()
 
1
import gradio as gr
from qdrant import qdrant_manager
from openai_manager import openai_manager


def generate(keywords):
    """Generate a sales role-play script from comma-separated keywords.

    Embeds the user's keywords, retrieves the nearest example scripts from
    Qdrant, and asks the chat model to write a new script in the same style.

    Args:
        keywords (str): comma-separated keywords typed into the Gradio box.

    Returns:
        str | None: the generated script, or None if the OpenAI call failed.
    """
    # Split on commas and trim whitespace. str.split never raises, so the
    # original try/except was dead code; warn instead when nothing usable
    # was entered.
    keywords_list = [part.strip() for part in (keywords or "").split(",") if part.strip()]
    if not keywords_list:
        gr.Warning("Please use ',' to separate Keywords")

    query = " ".join(keywords_list)
    print("kewords", query)

    # BUG FIX: the original embedded the hard-coded debug string
    # "hello my name is zinou" instead of the user's keywords, so vector
    # search ignored the actual input. Also pass a list: get_embeddings
    # expects a list of prompts, not a bare string.
    embedding = openai_manager.get_embeddings([query])

    points = qdrant_manager.search_point(query_vector=embedding[0])

    return openai_manager.shots(points, query)


iface = gr.Interface(fn=generate, inputs="text", outputs="text")
iface.launch()
openai_manager.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from dotenv import load_dotenv
3
+ import openai
4
+
5
+
6
class OpenAiManager:
    """Thin wrapper around the OpenAI SDK for completions and embeddings.

    Credentials are loaded from the environment (.env) at construction time.
    Network errors are printed and surfaced as a ``None`` (or empty-list)
    return value rather than raised.
    """

    def __init__(self):
        load_dotenv()
        openai.api_key = os.getenv("OPENAI_API_KEY")
        openai.organization = os.getenv("OPENAI_ORGANIZATION")

    def get_completion(
        self,
        prompt,
        model="text-davinci-003",
        max_tokens=128,
        temperature=0,
    ):
        """Return the text of a single completion, or None on failure."""
        response = None
        try:
            response = openai.Completion.create(
                prompt=prompt,
                max_tokens=max_tokens,
                model=model,
                temperature=temperature,
            )["choices"][0]["text"]

        except Exception as err:
            print(f"Sorry, There was a problem \n\n {err}")

        return response

    def get_chat_completion(self, prompt, model="gpt-3.5-turbo"):
        """Send *prompt* as a single system message and return the stripped
        reply text, or None on failure."""
        response = None
        try:
            response = (
                openai.ChatCompletion.create(
                    model=model,
                    messages=[
                        {
                            "role": "system",
                            "content": prompt,
                        }
                    ],
                )
                .choices[0]
                .message.content.strip()
            )

        except Exception as err:
            print(f"Sorry, There was a problem \n\n {err}")

        return response

    def get_embedding(self, prompt, model="text-embedding-ada-002"):
        """Return the embedding vector for one string, or None on failure."""
        # Newlines are stripped because they can degrade embedding quality.
        prompt = prompt.replace("\n", " ")

        embedding = None
        try:
            embedding = openai.Embedding.create(input=[prompt], model=model)["data"][0][
                "embedding"
            ]

        except Exception as err:
            print(f"Sorry, There was a problem {err}")

        return embedding

    def get_embeddings(self, prompts, model="text-embedding-ada-002"):
        """Return a list of embedding vectors for *prompts* (list of str).

        Returns an empty list on failure. BUG FIX: the original iterated the
        result even after a failed request, raising TypeError on None.
        """
        prompts = [prompt.replace("\n", " ") for prompt in prompts]

        try:
            embeddings = openai.Embedding.create(input=prompts, model=model)["data"]
        except Exception as err:
            print(f"Sorry, There was a problem {err}")
            return []

        return [embedding["embedding"] for embedding in embeddings]

    def shots(self, examples, keywords):
        """Build a few-shot prompt from retrieved *examples* and ask the chat
        model to produce a script for *keywords*.

        Each example is expected to carry ``payload["keywords"]`` and
        ``payload["example"]`` (Qdrant search hits).
        """
        prompt = []

        for example in examples:
            prompt.append(
                f"""
            keywords: {example.payload["keywords"]}
            script: {example.payload["example"]}
            """
            )

        # Trailing stub the model is asked to complete.
        prompt.append(
            f"""
            keywords: {keywords}
            script:
            """
        )

        prompt = "\n\n".join(prompt)

        return self.get_chat_completion(prompt=prompt)


openai_manager = OpenAiManager()
qdrant.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from qdrant_client import QdrantClient
3
+ from qdrant_client.http import models
4
+
5
+ from dotenv import load_dotenv
6
+ from uuid import uuid4
7
+
8
# Load variables from the local .env file into the process environment.
load_dotenv()

# Qdrant / collection configuration, read from the environment.
# NOTE(review): os.getenv returns str (or None when unset) — COLLECTION_SIZE
# and QDRANT_PORT are used as ints downstream; confirm they are cast before use.
COLLECTION_NAME = os.getenv("COLLECTION_NAME")
COLLECTION_SIZE = os.getenv("COLLECTION_SIZE")
QDRANT_PORT = os.getenv("QDRANT_PORT")
QDRANT_HOST = os.getenv("QDRANT_HOST")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
15
+
16
+
17
class QdrantManager:
    """
    A class for managing collections in the Qdrant database.

    Args:
        collection_name (str): The name of the collection to manage.
        collection_size (int): The dimensionality of the collection's vectors.
        port (int): The port number for the Qdrant API.
        host (str): The hostname or IP address for the Qdrant server.
        api_key (str): The API key for authenticating with the Qdrant server.
        recreate_collection (bool): Whether to recreate the collection if it already exists.

    Attributes:
        client (qdrant_client.QdrantClient): The Qdrant client object for interacting with the API.
    """

    def __init__(
        self,
        collection_name=COLLECTION_NAME,
        collection_size: int = COLLECTION_SIZE,
        port: int = QDRANT_PORT,
        host=QDRANT_HOST,
        api_key=QDRANT_API_KEY,
        recreate_collection: bool = False,
    ):
        # BUG FIX: the defaults come from os.getenv and are strings; normalize
        # to int so recreate_collection() sends a numeric vector size and the
        # client gets a numeric port.
        if collection_size is not None:
            collection_size = int(collection_size)
        if port is not None:
            port = int(port)

        self.collection_name = collection_name
        self.collection_size = collection_size
        self.host = host
        self.port = port
        self.api_key = api_key

        self.client = QdrantClient(host=host, port=port, api_key=api_key)
        self.setup_collection(collection_size, recreate_collection)

    def setup_collection(self, collection_size: int, recreate_collection: bool):
        """Ensure the collection exists with the expected vector size.

        Raises:
            ValueError: if an existing collection has a different vector size
                and ``recreate_collection`` was not requested.
        """
        if recreate_collection:
            self.recreate_collection()

        try:
            collection_info = self.get_collection_info()
        except Exception as e:
            # Best-effort: the collection may simply not exist yet.
            print(e)
            return

        # BUG FIX: the original raised this ValueError inside the same try
        # block whose `except Exception` printed it, so the mismatch was
        # always swallowed. Raise outside the try so callers actually see it.
        if collection_info["vector_size"] != int(collection_size):
            raise ValueError(
                f"""
                Existing collection {self.collection_name} has different collection size
                To use the new collection configuration, you need to recreate the collection as it already exists with a different configuration.
                use recreate_collection = True.
                """
            )

    def recreate_collection(self):
        """Drop and recreate the collection with cosine distance."""
        self.client.recreate_collection(
            collection_name=self.collection_name,
            vectors_config=models.VectorParams(
                size=self.collection_size, distance=models.Distance.COSINE
            ),
        )

    def get_collection_info(self):
        """Return point/vector counts and the configured vector size as ints."""
        collection_info = self.client.get_collection(
            collection_name=self.collection_name
        )

        return {
            "points_count": int(collection_info.points_count),
            "vectors_count": int(collection_info.vectors_count),
            "indexed_vectors_count": int(collection_info.indexed_vectors_count),
            "vector_size": int(collection_info.config.params.vectors.size),
        }

    def search_point(self, query_vector, limit=5):
        """Return up to *limit* nearest points for *query_vector*."""
        response = self.client.search(
            collection_name=self.collection_name,
            query_vector=query_vector,
            limit=limit,
        )

        return response


qdrant_manager = QdrantManager()
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ gradio
2
+ qdrant-client
3
+ openai
4
+ python-dotenv