acumplid committed
Commit · 5b3eab8 · 1 Parent(s): bc13def
replace endpoint
rag.py CHANGED
@@ -4,7 +4,7 @@ import requests
 from transformers import AutoTokenizer, AutoModelForSequenceClassification
 import torch
 from openai import OpenAI
-from huggingface_hub import snapshot_download
+from huggingface_hub import snapshot_download, InferenceClient
 
 from langchain_community.vectorstores import FAISS
 from langchain_community.embeddings import HuggingFaceEmbeddings
@@ -105,26 +105,40 @@ class RAG:
 
     def predict_completion(self, instruction, context, model_parameters):
 
-        client = OpenAI(
-            base_url=os.getenv("MODEL"),
-            api_key=os.getenv("HF_TOKEN")
-        )
-
-        query = f"Context:\n{context}\n\nQuestion:\n{instruction}"
+        # client = OpenAI(
+        #     base_url=os.getenv("MODEL"),
+        #     api_key=os.getenv("HF_TOKEN")
+        # )
+
+        # query = f"Context:\n{context}\n\nQuestion:\n{instruction}"
+
+        # chat_completion = client.chat.completions.create(
+        #     model="tgi",
+        #     messages=[
+        #         {"role": "user", "content": instruction}
+        #     ],
+        #     temperature=model_parameters["temperature"],
+        #     max_tokens=model_parameters["max_new_tokens"],
+        #     stream=False,
+        #     stop=["<|im_end|>"],
+        #     extra_body = {
+        #         "presence_penalty": model_parameters["repetition_penalty"] - 2,
+        #         "do_sample": False
+        #     }
+        # )
+
+        client = InferenceClient(api_key=os.getenv("HF_TOKEN"), model="meta-llama/Llama-3.3-70B-Instruct")
 
         chat_completion = client.chat.completions.create(
-            model="tgi",
             messages=[
                 {"role": "user", "content": instruction}
             ],
             temperature=model_parameters["temperature"],
            max_tokens=model_parameters["max_new_tokens"],
+            presence_penalty=model_parameters["repetition_penalty"] - 2,
+            top_p=0.7,
             stream=False,
             stop=["<|im_end|>"],
         )
 
         response = chat_completion.choices[0].message.content
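For context, here is a minimal standalone sketch of the serverless inference path this commit switches to. Only the InferenceClient construction and the chat.completions.create arguments come from the diff; the model_parameters values, the example instruction, and the print call are illustrative assumptions.

import os
from huggingface_hub import InferenceClient

# Illustrative values only: the real Space fills model_parameters elsewhere.
model_parameters = {
    "temperature": 0.2,
    "max_new_tokens": 256,
    "repetition_penalty": 2.1,
}
instruction = "What is retrieval-augmented generation?"

# HF_TOKEN must be set in the environment (a Space secret in this repo).
client = InferenceClient(
    api_key=os.getenv("HF_TOKEN"),
    model="meta-llama/Llama-3.3-70B-Instruct",
)

chat_completion = client.chat.completions.create(
    messages=[{"role": "user", "content": instruction}],
    temperature=model_parameters["temperature"],
    max_tokens=model_parameters["max_new_tokens"],
    # The TGI-style repetition_penalty (>= 1.0) is shifted into the
    # OpenAI-style presence_penalty range of [-2.0, 2.0]: 2.1 becomes 0.1.
    presence_penalty=model_parameters["repetition_penalty"] - 2,
    top_p=0.7,
    stream=False,
    stop=["<|im_end|>"],
)

print(chat_completion.choices[0].message.content)

The presence_penalty argument carries over the mapping that previously lived in extra_body, now passed as a first-class parameter. Note also that, at least within this hunk, both the old and new paths send only instruction as the user message: the context-bearing query string was built but never passed, and the commit comments it out rather than wiring it in.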