sguertl committed · Commit 6dc93bc · verified · 1 Parent(s): 8706698

Use Llama 70B

Files changed (1)
  1. app.py +67 -15
app.py CHANGED
@@ -2,37 +2,89 @@ from huggingface_hub import InferenceClient
 from fastapi import FastAPI, Request
 from pydantic import BaseModel
 import uvicorn
+import requests
+import re
 import os
 
 app = FastAPI()
 
-MODEL = "mistralai/Mistral-7B-Instruct-v0.3"
+MODEL = "meta-llama/Llama-3.3-70B-Instruct"
 HF_TOKEN = os.environ["HF_TOKEN"]
+PROMPTS_DOC_URL = os.environ["PROMPTS"]
 
 client = InferenceClient(model=MODEL, token=HF_TOKEN)
 
+def fetch_prompts_from_google_doc():
+    print("Fetching prompts from Google Doc...")
+    response = requests.get(PROMPTS_DOC_URL)
+    if response.status_code != 200:
+        raise Exception("Failed to fetch document")
+
+    text = response.text
+    prompts = {}
+
+    pattern = r"\{BEGIN (.*?)\}([\s\S]*?)\{END \1\}"
+    matches = re.findall(pattern, text)
+
+    for key, content in matches:
+        prompts[key.strip()] = content.strip()
+
+    return prompts
+
 class Prompt(BaseModel):
     message: str
+    code: str
 
 @app.post("/chat")
 async def chat(prompt: Prompt):
+    prompts = fetch_prompts_from_google_doc()
     print("Received POST request")
     print("Message:", prompt.message)
-    system_prompt = (
-        "You are a beginner programming student helping a peer. "
-        "Offer hints, ask questions, and support understanding—don’t give full solutions."
-    )
-    full_prompt = f"<s>[INST] <<SYS>>{system_prompt}<</SYS>>\n{prompt.message} [/INST]"
 
-    print("Full Prompt:", full_prompt)
-
-    output = client.text_generation(
-        prompt=full_prompt,
-        max_new_tokens=200,
-        temperature=0.7,
-        top_p=0.95,
-        do_sample=True
+    system_prompt = f"""
+### Unit Information ###
+{prompts['UNIT_INFORMATION']}
+
+### Role Description ###
+{prompts['ROLE_DESCRIPTION']}
+
+### Topic Information ###
+{prompts['TOPIC_INFORMATION']}
+
+### Task Description ###
+{prompts['TASK_DESCRIPTION']}
+
+### Reference Solution ###
+{prompts['REFERENCE_SOLUTION']}
+
+### Behavioral Instructions ###
+{prompts['BEHAVIORAL_INSTRUCTIONS']}
+"""
+
+    user_prompt = f"""
+### Message ###
+{prompt.message}
+
+### Code ###
+{prompt.code}
+"""
+
+    response = client.chat_completion(
+        [
+            {
+                "role": "system",
+                "content": system_prompt,
+            },
+            {
+                "role": "user",
+                "content": user_prompt
+            },
+        ],
+        max_tokens=2048,
+        temperature=0.2,
     )
+
+    text_response = response["choices"][0]["message"]["content"]
 
-    print("Text generation done", output.strip())
-    return {"reply": output.strip()}
+    print("Text generation done", text_response.strip())
+    return {"reply": text_response.strip()}
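Note on the change: the hand-built Mistral <s>[INST] ... [/INST] prompt and text_generation call are replaced by chat_completion, which formats the system/user messages with the model's own chat template, and the prompt text itself moves out of the code into a fetched document. The new fetch_prompts_from_google_doc helper assumes the document behind the PROMPTS env var (e.g. a Google Doc exported as plain text) wraps each section in {BEGIN KEY} ... {END KEY} markers; the regex's \1 backreference pairs each END with its matching BEGIN. A minimal sketch of the parsing step, using a made-up document body (the real doc's contents are not part of this commit):

import re

# Hypothetical doc text, for illustration only.
sample_doc = """
{BEGIN ROLE_DESCRIPTION}
You are a beginner programming student helping a peer.
{END ROLE_DESCRIPTION}
{BEGIN TASK_DESCRIPTION}
Write a function that reverses a string.
{END TASK_DESCRIPTION}
"""

# Same pattern as in app.py: captures the key and everything up to the
# matching {END <key>} marker, non-greedily.
pattern = r"\{BEGIN (.*?)\}([\s\S]*?)\{END \1\}"
prompts = {key.strip(): content.strip() for key, content in re.findall(pattern, sample_doc)}

print(prompts["TASK_DESCRIPTION"])  # -> Write a function that reverses a string.

Because chat() calls the helper on every request, the doc is re-fetched per POST, so prompt edits in the document take effect without redeploying the Space.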
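The Prompt model now requires a code field alongside message, so existing clients that send only {"message": ...} will get a 422 validation error from FastAPI. A usage sketch for the updated endpoint (the URL is a placeholder for wherever this Space is hosted):

import requests

API_URL = "https://<your-space>.hf.space/chat"  # placeholder URL, not from this commit

payload = {
    "message": "Why does my loop never stop?",
    "code": "i = 0\nwhile i < 10:\n    print(i)",
}

resp = requests.post(API_URL, json=payload)
resp.raise_for_status()
print(resp.json()["reply"])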