ManojINaik committed on
Commit d628814 · verified · 1 Parent(s): 9c76e91

Update main.py

Files changed (1)
  1. main.py +80 -48
main.py CHANGED
@@ -1,59 +1,91 @@
 from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
-from huggingface_hub import InferenceClient
+from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
+import torch
+from typing import Optional, List
 
-app = FastAPI()
+app = FastAPI(title="LLM API", description="API for interacting with LLaMA model")
 
-# Use your model
-client = InferenceClient("ManojINaik/codsw")
+# Model configuration
+class ModelConfig:
+    model_name = "ManojINaik/Strength_weakness"  # Your fine-tuned model
+    device = "cpu"
+    max_length = 200
+    temperature = 0.7
 
-class Item(BaseModel):
+# Request/Response models
+class GenerateRequest(BaseModel):
     prompt: str
-    history: list
-    system_prompt: str
-    temperature: float = 0.0
-    max_new_tokens: int = 1048
-    top_p: float = 0.15
-    repetition_penalty: float = 1.0
-
-def format_prompt(message, history):
-    prompt = "<s>"
-    for user_prompt, bot_response in history:
-        prompt += f"[INST] {user_prompt} [/INST]"
-        prompt += f" {bot_response}</s> "
-    prompt += f"[INST] {message} [/INST]"
-    return prompt
-
-def generate(item: Item):
+    history: Optional[List[str]] = []
+    system_prompt: Optional[str] = "You are a very powerful AI assistant."
+    max_length: Optional[int] = 200
+    temperature: Optional[float] = 0.7
+
+class GenerateResponse(BaseModel):
+    response: str
+
+# Global variables for model and tokenizer
+model = None
+tokenizer = None
+generator = None
+
+@app.on_event("startup")
+async def load_model():
+    global model, tokenizer, generator
     try:
-        # Ensure valid temperature
-        temperature = max(float(item.temperature), 1e-2)
-        top_p = float(item.top_p)
-
-        generate_kwargs = {
-            "temperature": temperature,
-            "max_new_tokens": item.max_new_tokens,
-            "top_p": top_p,
-            "repetition_penalty": item.repetition_penalty,
-            "do_sample": True,
-            "seed": 42,
-        }
-
-        # Format the prompt
-        formatted_prompt = format_prompt(f"{item.system_prompt}, {item.prompt}", item.history)
-
-        # Call text_generation on your model (correct argument: formatted_prompt)
-        stream = client.text_generation(
-            formatted_prompt,  # Use the formatted prompt directly
-            **generate_kwargs,
-            stream=True,
+        print("Loading model and tokenizer...")
+        tokenizer = AutoTokenizer.from_pretrained(ModelConfig.model_name)
+        model = AutoModelForCausalLM.from_pretrained(
+            ModelConfig.model_name,
+            torch_dtype=torch.float32,
+            device_map=ModelConfig.device,
+            low_cpu_mem_usage=True
+        )
+        generator = pipeline(
+            "text-generation",
+            model=model,
+            tokenizer=tokenizer,
+            device=ModelConfig.device
+        )
+        print("Model loaded successfully!")
+    except Exception as e:
+        print(f"Error loading model: {str(e)}")
+        raise e
+
+@app.post("/generate/", response_model=GenerateResponse)
+async def generate_text(request: GenerateRequest):
+    if generator is None:
+        raise HTTPException(status_code=500, detail="Model not loaded")
+
+    try:
+        # Format the prompt with system prompt and chat history
+        formatted_prompt = f"{request.system_prompt}\n\n"
+        for msg in request.history:
+            formatted_prompt += f"{msg}\n"
+        formatted_prompt += f"Human: {request.prompt}\nAssistant:"
+
+        # Generate response
+        outputs = generator(
+            formatted_prompt,
+            max_length=request.max_length,
+            temperature=request.temperature,
+            num_return_sequences=1,
+            do_sample=True,
+            pad_token_id=tokenizer.pad_token_id,
+            eos_token_id=tokenizer.eos_token_id
         )
-        output = "".join([response.token.text for response in stream])
-        return output
 
+        # Extract the generated text
+        generated_text = outputs[0]['generated_text']
+
+        # Remove the prompt from the response
+        response = generated_text.split("Assistant:")[-1].strip()
+
+        return {"response": response}
+
     except Exception as e:
-        raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")
+        raise HTTPException(status_code=500, detail=f"Error generating text: {str(e)}")
 
-@app.post("/generate/")
-async def generate_text(item: Item):
-    return {"response": generate(item)}
+@app.get("/")
+def root():
+    return {"message": "LLM API is running. Use /generate endpoint for text generation."}