ManojINaik commited on
Commit
699be26
·
verified ·
1 Parent(s): 5f8ebb7

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +33 -26
main.py CHANGED
@@ -1,11 +1,10 @@
1
- from fastapi import FastAPI
2
  from pydantic import BaseModel
3
- from huggingface_hub import InferenceClient
4
- import uvicorn
5
-
6
 
7
  app = FastAPI()
8
 
 
9
  client = InferenceClient("ManojINaik/codsw")
10
 
11
  class Item(BaseModel):
@@ -26,29 +25,37 @@ def format_prompt(message, history):
26
  return prompt
27
 
28
  def generate(item: Item):
29
- temperature = float(item.temperature)
30
- if temperature < 1e-2:
31
- temperature = 1e-2
32
- top_p = float(item.top_p)
33
-
34
- generate_kwargs = dict(
35
- temperature=temperature,
36
- max_new_tokens=item.max_new_tokens,
37
- top_p=top_p,
38
- repetition_penalty=item.repetition_penalty,
39
- do_sample=True,
40
- seed=42,
41
- )
42
-
43
- formatted_prompt = format_prompt(f"{item.system_prompt}, {item.prompt}", item.history)
44
- stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
45
- output = ""
46
-
47
- for response in stream:
48
- output += response.token.text
49
- return output
 
 
 
 
 
 
 
 
 
50
 
51
  @app.post("/generate/")
52
  async def generate_text(item: Item):
53
  return {"response": generate(item)}
54
-
 
1
+ from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
+ from huggingface_hub import InferenceClient, BadRequestError
 
 
4
 
5
  app = FastAPI()
6
 
7
+ # Use your model
8
  client = InferenceClient("ManojINaik/codsw")
9
 
10
  class Item(BaseModel):
 
25
  return prompt
26
 
27
def generate(item: Item):
    """Run streamed text generation on the hosted model and return the full text.

    Parameters:
        item: request payload carrying prompt, system_prompt, history and
            sampling parameters (temperature, top_p, max_new_tokens,
            repetition_penalty).

    Returns:
        The concatenated generated text.

    Raises:
        HTTPException(400): the inference backend rejected the request.
        HTTPException(500): any other failure during generation.
    """
    try:
        # Clamp temperature to a small positive floor — the backend rejects 0.
        temperature = max(float(item.temperature), 1e-2)
        top_p = float(item.top_p)

        generate_kwargs = {
            "temperature": temperature,
            "max_new_tokens": item.max_new_tokens,
            "top_p": top_p,
            "repetition_penalty": item.repetition_penalty,
            "do_sample": True,
            "seed": 42,
        }

        # Format the prompt with the conversation history.
        formatted_prompt = format_prompt(f"{item.system_prompt}, {item.prompt}", item.history)

        # BUGFIX: text_generation takes the prompt as its first positional
        # parameter (named `prompt`, not `inputs` — `inputs=` raises TypeError).
        # BUGFIX: `details=True` is required so the stream yields structured
        # outputs exposing `.token.text`; with details=False (the previous
        # revision's implicit default) the stream yields plain strings and
        # `response.token` raises AttributeError.
        stream = client.text_generation(
            formatted_prompt,
            **generate_kwargs,
            stream=True,
            details=True,
            return_full_text=False,
        )
        return "".join(response.token.text for response in stream)

    except BadRequestError as e:
        raise HTTPException(status_code=400, detail=f"Bad request: {str(e)}")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")
58
 
59
@app.post("/generate/")
async def generate_text(item: Item):
    """POST /generate/ — run generation for the payload and wrap it in JSON."""
    result = generate(item)
    return {"response": result}