raghavNCI committed on
Commit
17fbf3d
·
1 Parent(s): 975c3a7

one more try mistral

Browse files
models_initialization/mistral_registry.py CHANGED
@@ -1,35 +1,51 @@
1
  import os
2
- import json
3
  import requests
 
4
  from dotenv import load_dotenv
5
 
6
  load_dotenv()
7
 
8
- HF_TOKEN = os.getenv("HF_TOKEN")
9
  HF_API_URL = "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.3"
10
 
11
  HEADERS = {
12
  "Authorization": f"Bearer {HF_TOKEN}",
13
- "Content-Type": "application/json"
14
  }
15
 
16
def mistral_generate(prompt: str, max_new_tokens=128, temperature=0.7) -> str:
    """Generate text with Mistral-7B-Instruct-v0.3 through the HF Inference API.

    Posts ``prompt`` to ``HF_API_URL`` using the module-level ``HEADERS`` and
    returns the stripped ``generated_text`` of the first result.

    Args:
        prompt: Text sent as the ``inputs`` field of the request.
        max_new_tokens: Forwarded as ``parameters.max_new_tokens``.
        temperature: Forwarded as ``parameters.temperature``.

    Returns:
        The generated text, or ``""`` when the request fails or the response
        does not contain a usable ``generated_text`` string.
    """
    payload = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
        },
        # `wait_for_model` is a request option, not a generation parameter:
        # the Inference API documents it under the top-level "options" key.
        "options": {"wait_for_model": True},
    }

    try:
        response = requests.post(
            HF_API_URL,
            headers=HEADERS,
            data=json.dumps(payload),
            timeout=30,
        )
        response.raise_for_status()
        result = response.json()
    except Exception as e:
        # Best-effort boundary: callers rely on "" instead of an exception.
        print("Mistral API error:", e)
        return ""

    # Validate the response shape explicitly rather than relying on the
    # exception handler to absorb AttributeError/TypeError.
    if isinstance(result, list) and result and isinstance(result[0], dict):
        text = result[0].get("generated_text")
        if isinstance(text, str):
            return text.strip()

    return ""
 
1
  import os
 
2
  import requests
3
+ import json
4
  from dotenv import load_dotenv
5
 
6
  load_dotenv()
7
 
8
+ HF_TOKEN = os.getenv("HF_TOKEN")
9
  HF_API_URL = "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.3"
10
 
11
  HEADERS = {
12
  "Authorization": f"Bearer {HF_TOKEN}",
13
+ "Content-Type": "application/json"
14
  }
15
 
16
def mistral_generate(prompt: str,
                     max_new_tokens: int = 128,
                     temperature: float = 0.7) -> str:
    """Generate text with Mistral-7B-Instruct-v0.3 through the HF Inference API.

    Posts ``prompt`` to ``HF_API_URL`` using the module-level ``HEADERS`` and
    returns the stripped ``generated_text`` of the first result. The request
    asks the API to wait while the model loads.

    Args:
        prompt: Text sent as the ``inputs`` field of the request.
        max_new_tokens: Forwarded as ``parameters.max_new_tokens``.
        temperature: Forwarded as ``parameters.temperature``.

    Returns:
        The generated text, or ``""`` on any failure: transport/HTTP error,
        a body that is not valid JSON, an ``{"error": ...}`` payload, or a
        response without a usable ``generated_text`` string.
    """
    payload = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
        },
        # Documented place for the flag in the Inference API request body.
        "options": {"wait_for_model": True},
    }

    response = None
    try:
        response = requests.post(
            HF_API_URL,
            headers=HEADERS,
            # Kept from the previous revision; NOTE(review): confirm the API
            # honours this query parameter in addition to `options` above.
            params={"wait_for_model": "true"},
            json=payload,   # let requests serialise and set the content type
            timeout=90,     # model cold starts can take a while
        )
        response.raise_for_status()
        data = response.json()
    except (requests.exceptions.RequestException, ValueError) as e:
        # ValueError covers JSON decoding failures on requests versions whose
        # decode error is not a RequestException subclass.
        print("❌ Mistral API error:", str(e))
        body = getattr(response, "text", "")
        if body:
            # The body usually carries the quota/auth explanation.
            print("❌ Mistral API response body:", body[:500])
        return ""

    # The API can report a failure as a JSON object with an "error" field.
    if isinstance(data, dict) and "error" in data:
        print("❌ Mistral API error:", str(data["error"]))
        return ""

    # Text-generation models answer with a list of {"generated_text": ...}.
    if isinstance(data, list) and data and isinstance(data[0], dict):
        text = data[0].get("generated_text")
        if isinstance(text, str):
            return text.strip()

    return ""