raghavNCI committed on
Commit f0f712f · 1 Parent(s): 17fbf3d

Mistral now hosted on AWS

Dockerfile CHANGED
@@ -11,6 +11,7 @@ WORKDIR /app
 
 COPY --chown=user ./requirements.txt requirements.txt
 RUN pip install --no-cache-dir --upgrade -r requirements.txt
+RUN pip install --no-cache-dir boto3
 
 COPY --chown=user . /app
 CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
models_initialization/mistral_registry.py CHANGED
@@ -1,51 +1,67 @@
 import os
-import requests
 import json
-from dotenv import load_dotenv
+import boto3
+from botocore.config import Config
+from botocore.exceptions import BotoCoreError, ClientError
 
-load_dotenv()
+# ──────────────────────────────────────────────────────────────
+# Environment variables you need (add them in your HF Space)
+# ──────────────────────────────────────────────────────────────
+# AWS_ACCESS_KEY_ID
+# AWS_SECRET_ACCESS_KEY
+# AWS_REGION              → e.g. "us-east-1"
+# SAGEMAKER_ENDPOINT_NAME → e.g. "mistral-endpoint"
+# ──────────────────────────────────────────────────────────────
 
-HF_TOKEN = os.getenv("HF_TOKEN")
-HF_API_URL = "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.3"
+AWS_REGION = os.getenv("AWS_REGION", "us-east-1")
+ENDPOINT = os.getenv("SAGEMAKER_ENDPOINT_NAME", "mistral-endpoint")
 
-HEADERS = {
-    "Authorization": f"Bearer {HF_TOKEN}",
-    "Content-Type": "application/json"
-}
+# Optional: configure retries / timeouts
+boto_cfg = Config(
+    retries={"max_attempts": 3, "mode": "standard"},
+    connect_timeout=10,
+    read_timeout=120,
+)
 
-def mistral_generate(prompt: str,
-                     max_new_tokens: int = 128,
-                     temperature: float = 0.7) -> str:
+sm_client = boto3.client("sagemaker-runtime", region_name=AWS_REGION, config=boto_cfg)
+
+
+def mistral_generate(
+    prompt: str,
+    max_new_tokens: int = 128,
+    temperature: float = 0.7,
+) -> str:
     """
-    Call the HF Inference-API for Mistral-7B-Instruct-v0.3.
-    - Automatically waits while the model spins up (`wait_for_model=true`).
-    - Returns the generated text or an empty string on failure.
+    Call the SageMaker endpoint that hosts Mistral-7B.
+    Returns the generated text or an empty string on failure.
     """
     payload = {
         "inputs": prompt,
         "parameters": {
             "max_new_tokens": max_new_tokens,
-            "temperature": temperature
-        }
+            "temperature": temperature,
+        },
     }
 
     try:
-        r = requests.post(
-            HF_API_URL,
-            headers=HEADERS,
-            params={"wait_for_model": "true"},  # ⭐ key change
-            json=payload,                       # use `json=` not `data=`
-            timeout=90                          # give the model time to load
+        # Invoke the endpoint
+        response = sm_client.invoke_endpoint(
+            EndpointName=ENDPOINT,
+            ContentType="application/json",
+            Body=json.dumps(payload).encode("utf-8"),
         )
-        r.raise_for_status()
-        data = r.json()
 
-        # HF returns a list of generated texts for standard text-generation models
-        if isinstance(data, list) and data:
-            return data[0].get("generated_text", "").strip()
+        # SageMaker returns a byte stream → decode & load JSON
+        result = json.loads(response["Body"].read())
+
+        if isinstance(result, list) and result:
+            return result[0].get("generated_text", "").strip()
 
-    except requests.exceptions.RequestException as e:
-        # You might want to log `r.text` as well for quota or auth errors
-        print("❌ Mistral API error:", str(e))
+    except (BotoCoreError, ClientError) as e:
+        # Log SageMaker errors (throttling, auth, etc.)
+        print("❌ SageMaker invocation error:", str(e))
+
+    except Exception as e:
+        print("❌ Unknown error:", str(e))
 
     return ""