acumplid committed
Commit 5b3eab8 · Parent: bc13def

replace endpoint

Files changed (1):
  1. rag.py +26 -12
rag.py CHANGED

@@ -4,7 +4,7 @@ import requests
 from transformers import AutoTokenizer, AutoModelForSequenceClassification
 import torch
 from openai import OpenAI
-from huggingface_hub import snapshot_download
+from huggingface_hub import snapshot_download, InferenceClient
 
 from langchain_community.vectorstores import FAISS
 from langchain_community.embeddings import HuggingFaceEmbeddings

@@ -105,26 +105,40 @@ class RAG:
 
     def predict_completion(self, instruction, context, model_parameters):
 
-        client = OpenAI(
-            base_url=os.getenv("MODEL"),
-            api_key=os.getenv("HF_TOKEN")
-        )
-
-        query = f"Context:\n{context}\n\nQuestion:\n{instruction}"
+        # client = OpenAI(
+        #     base_url=os.getenv("MODEL"),
+        #     api_key=os.getenv("HF_TOKEN")
+        # )
+
+        # query = f"Context:\n{context}\n\nQuestion:\n{instruction}"
+
+        # chat_completion = client.chat.completions.create(
+        #     model="tgi",
+        #     messages=[
+        #         {"role": "user", "content": instruction}
+        #     ],
+        #     temperature=model_parameters["temperature"],
+        #     max_tokens=model_parameters["max_new_tokens"],
+        #     stream=False,
+        #     stop=["<|im_end|>"],
+        #     extra_body={
+        #         "presence_penalty": model_parameters["repetition_penalty"] - 2,
+        #         "do_sample": False
+        #     }
+        # )
+
+        client = InferenceClient(api_key=os.getenv("HF_TOKEN"), model="meta-llama/Llama-3.3-70B-Instruct")
 
         chat_completion = client.chat.completions.create(
-            model="tgi",
             messages=[
                 {"role": "user", "content": instruction}
             ],
             temperature=model_parameters["temperature"],
             max_tokens=model_parameters["max_new_tokens"],
+            presence_penalty=model_parameters["repetition_penalty"] - 2,
+            top_p=0.7,
             stream=False,
             stop=["<|im_end|>"],
-            extra_body={
-                "presence_penalty": model_parameters["repetition_penalty"] - 2,
-                "do_sample": False
-            }
         )
 
         response = chat_completion.choices[0].message.content
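
For context: huggingface_hub's InferenceClient exposes the same OpenAI-style client.chat.completions.create interface, so the call site barely changes; the TGI-only extra_body parameters become first-class arguments, with the TGI-style repetition_penalty (typically >= 1) shifted into the OpenAI-style presence_penalty range of [-2, 2]. Below is a minimal standalone sketch of the new call path, assuming HF_TOKEN is set in the environment; the model_parameters values and the prompt are hypothetical examples, since rag.py receives them from the caller.

import os

from huggingface_hub import InferenceClient

# Client for the Hugging Face Inference API; model name taken from the commit.
client = InferenceClient(
    api_key=os.getenv("HF_TOKEN"),
    model="meta-llama/Llama-3.3-70B-Instruct",
)

# Hypothetical example values; in rag.py these arrive via model_parameters.
model_parameters = {
    "temperature": 0.2,
    "max_new_tokens": 256,
    "repetition_penalty": 1.0,
}

chat_completion = client.chat.completions.create(
    messages=[{"role": "user", "content": "What is retrieval-augmented generation?"}],
    temperature=model_parameters["temperature"],
    max_tokens=model_parameters["max_new_tokens"],
    # TGI repetition_penalty (>= 1) shifted into the OpenAI-style
    # presence_penalty range of [-2, 2], mirroring the diff above.
    presence_penalty=model_parameters["repetition_penalty"] - 2,
    top_p=0.7,
    stream=False,
    stop=["<|im_end|>"],
)

print(chat_completion.choices[0].message.content)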