raghavNCI commited on
Commit
e30a3df
Β·
1 Parent(s): f0f712f

mistral connection with inference

Browse files
models_initialization/mistral_registry.py CHANGED
@@ -1,29 +1,19 @@
1
  import os
2
  import json
3
- import boto3
4
- from botocore.config import Config
5
- from botocore.exceptions import BotoCoreError, ClientError
6
 
7
- # ──────────────────────────────────────────────────────────────
8
- # Environment variables you need (add them in your HF Space)
9
- # ──────────────────────────────────────────────────────────────
10
- # AWS_ACCESS_KEY_ID
11
- # AWS_SECRET_ACCESS_KEY
12
- # AWS_REGION β†’ e.g. "us-east-1"
13
- # SAGEMAKER_ENDPOINT_NAME β†’ e.g. "mistral-endpoint"
14
- # ──────────────────────────────────────────────────────────────
15
 
16
- AWS_REGION = os.getenv("AWS_REGION", "us-east-1")
17
- ENDPOINT = os.getenv("SAGEMAKER_ENDPOINT_NAME", "mistral-endpoint")
18
 
19
- # Optional: configure retries / timeouts
20
- boto_cfg = Config(
21
- retries={"max_attempts": 3, "mode": "standard"},
22
- connect_timeout=10,
23
- read_timeout=120,
24
- )
25
-
26
- sm_client = boto3.client("sagemaker-runtime", region_name=AWS_REGION, config=boto_cfg)
27
 
28
 
29
  def mistral_generate(
@@ -32,8 +22,8 @@ def mistral_generate(
32
  temperature: float = 0.7,
33
  ) -> str:
34
  """
35
- Call the SageMaker endpoint that hosts Mistral-7B.
36
- Returns the generated text or an empty string on failure.
37
  """
38
  payload = {
39
  "inputs": prompt,
@@ -44,22 +34,26 @@ def mistral_generate(
44
  }
45
 
46
  try:
47
- # Invoke the endpoint
48
- response = sm_client.invoke_endpoint(
49
- EndpointName=ENDPOINT,
50
- ContentType="application/json",
51
- Body=json.dumps(payload).encode("utf-8"),
52
  )
53
-
54
- # SageMaker returns a byte stream β†’ decode & load JSON
55
- result = json.loads(response["Body"].read())
56
-
57
- if isinstance(result, list) and result:
58
- return result[0].get("generated_text", "").strip()
59
-
60
- except (BotoCoreError, ClientError) as e:
61
- # Log SageMaker errors (throttling, auth, etc.)
62
- print("❌ SageMaker invocation error:", str(e))
 
 
 
 
63
 
64
  except Exception as e:
65
  print("❌ Unknown error:", str(e))
 
1
  import os
2
  import json
3
+ import requests
4
+ from requests.exceptions import RequestException
 
5
 
6
+ HF_ENDPOINT_URL = os.getenv("HF_ENDPOINT_URL")
7
+ HF_ENDPOINT_TOKEN = os.getenv("HF_ENDPOINT_TOKEN")
 
 
 
 
 
 
8
 
9
+ assert HF_ENDPOINT_URL, "❌ HF_ENDPOINT_URL is not set"
10
+ assert HF_ENDPOINT_TOKEN, "❌ HF_ENDPOINT_TOKEN is not set"
11
 
12
+ HEADERS = {
13
+ "Authorization": f"Bearer {HF_ENDPOINT_TOKEN}",
14
+ "Content-Type": "application/json",
15
+ "Accept": "application/json",
16
+ }
 
 
 
17
 
18
 
19
  def mistral_generate(
 
22
  temperature: float = 0.7,
23
  ) -> str:
24
  """
25
+ Call the Hugging Face Inference Endpoint that hosts Mistral-7B.
26
+ Returns the generated text, or an empty string on failure.
27
  """
28
  payload = {
29
  "inputs": prompt,
 
34
  }
35
 
36
  try:
37
+ r = requests.post(
38
+ HF_ENDPOINT_URL,
39
+ headers=HEADERS,
40
+ json=payload,
41
+ timeout=90, # HF spins up cold endpoints too
42
  )
43
+ r.raise_for_status()
44
+ data = r.json()
45
+
46
+ # HF Endpoints usually return a *list* of dicts
47
+ if isinstance(data, list) and data:
48
+ return data[0].get("generated_text", "").strip()
49
+ # Some endpoints return a single dict
50
+ if isinstance(data, dict) and "generated_text" in data:
51
+ return data["generated_text"].strip()
52
+
53
+ except RequestException as e:
54
+ print("❌ HF Endpoint error:", str(e))
55
+ if e.response is not None:
56
+ print("Endpoint said:", e.response.text[:300])
57
 
58
  except Exception as e:
59
  print("❌ Unknown error:", str(e))