ALLOUNE
commited on
Commit
·
62151ed
1
Parent(s):
290a90d
add purpose search
Browse files- main.py +3 -2
- src/processor.py +9 -4
main.py
CHANGED
|
@@ -28,6 +28,7 @@ dataset = load_dataset("heymenn/Technologies", streaming=True, split="train")
|
|
| 28 |
|
| 29 |
class SearchInput(BaseModel):
|
| 30 |
title: str
|
|
|
|
| 31 |
|
| 32 |
class SearchOutput(BaseModel):
|
| 33 |
title: str
|
|
@@ -53,7 +54,7 @@ def post_search(payload: SearchInput):
|
|
| 53 |
"""
|
| 54 |
Endpoint that returns a search result.
|
| 55 |
"""
|
| 56 |
-
config = {"dataset": dataset, "model": model}
|
| 57 |
res = search_and_retrieve(payload.title, config)
|
| 58 |
return res
|
| 59 |
|
|
@@ -63,7 +64,7 @@ def post_generate_and_push(payload: GenerateInput):
|
|
| 63 |
Endpoint to generate a technology and push it to the dataset
|
| 64 |
"""
|
| 65 |
|
| 66 |
-
config = {"dataset": dataset, "model": model}
|
| 67 |
res = search_and_retrieve(payload.title, config)
|
| 68 |
if res["score"] >= 0.7 and not payload.force:
|
| 69 |
raise HTTPException(status_code=500, detail=f"Cannot generate the technology a high score of {res['score']} have been found for the technology : {res['title']}")
|
|
|
|
| 28 |
|
| 29 |
class SearchInput(BaseModel):
|
| 30 |
title: str
|
| 31 |
+
type: str = "title"
|
| 32 |
|
| 33 |
class SearchOutput(BaseModel):
|
| 34 |
title: str
|
|
|
|
| 54 |
"""
|
| 55 |
Endpoint that returns a search result.
|
| 56 |
"""
|
| 57 |
+
config = {"dataset": dataset, "model": model, "type": payload.type}
|
| 58 |
res = search_and_retrieve(payload.title, config)
|
| 59 |
return res
|
| 60 |
|
|
|
|
| 64 |
Endpoint to generate a technology and push it to the dataset
|
| 65 |
"""
|
| 66 |
|
| 67 |
+
config = {"dataset": dataset, "model": model, "type": "title"}
|
| 68 |
res = search_and_retrieve(payload.title, config)
|
| 69 |
if res["score"] >= 0.7 and not payload.force:
|
| 70 |
raise HTTPException(status_code=500, detail=f"Cannot generate the technology a high score of {res['score']} have been found for the technology : {res['title']}")
|
src/processor.py
CHANGED
|
@@ -18,8 +18,11 @@ def search_and_retrieve(user_input, config):
|
|
| 18 |
purpose = row["purpose"]
|
| 19 |
|
| 20 |
cosim = model.similarity(row["embeddings"], user_embedding)
|
| 21 |
-
|
| 22 |
-
|
|
|
|
|
|
|
|
|
|
| 23 |
fuzzy_score = token_set_ratio / 100
|
| 24 |
alpha = 0.6
|
| 25 |
combined_score = alpha * cosim + (1 - alpha) * fuzzy_score
|
|
@@ -77,9 +80,11 @@ def generate_tech(user_input, user_instructions):
|
|
| 77 |
|
| 78 |
<USER_INPUT>
|
| 79 |
{user_input}
|
| 80 |
-
</USER_INPUT>
|
| 81 |
"""
|
| 82 |
|
|
|
|
|
|
|
| 83 |
client = Client(api_key=os.getenv("GEMINI_API_KEY"))
|
| 84 |
|
| 85 |
# Define the grounding tool
|
|
@@ -111,4 +116,4 @@ def send_to_dataset(data, model):
|
|
| 111 |
|
| 112 |
dataset = load_dataset("heymenn/Technologies", split="train")
|
| 113 |
updated_dataset = dataset.add_item(data)
|
| 114 |
-
updated_dataset.push_to_hub("heymenn/Technologies")
|
|
|
|
| 18 |
purpose = row["purpose"]
|
| 19 |
|
| 20 |
cosim = model.similarity(row["embeddings"], user_embedding)
|
| 21 |
+
if config["type"] == "purpose":
|
| 22 |
+
token_set_ratio = fuzz.token_set_ratio(user_input, purpose)
|
| 23 |
+
else:
|
| 24 |
+
token_set_ratio = fuzz.token_set_ratio(user_input, name)
|
| 25 |
+
|
| 26 |
fuzzy_score = token_set_ratio / 100
|
| 27 |
alpha = 0.6
|
| 28 |
combined_score = alpha * cosim + (1 - alpha) * fuzzy_score
|
|
|
|
| 80 |
|
| 81 |
<USER_INPUT>
|
| 82 |
{user_input}
|
| 83 |
+
</USER_INPUT>
|
| 84 |
"""
|
| 85 |
|
| 86 |
+
client = Client(api_key=os.getenv("GEMINI_API_KEY"))
|
| 87 |
+
|
| 88 |
client = Client(api_key=os.getenv("GEMINI_API_KEY"))
|
| 89 |
|
| 90 |
# Define the grounding tool
|
|
|
|
| 116 |
|
| 117 |
dataset = load_dataset("heymenn/Technologies", split="train")
|
| 118 |
updated_dataset = dataset.add_item(data)
|
| 119 |
+
updated_dataset.push_to_hub("heymenn/Technologies")
|