Phoenix21 committed on
Commit
028f06a
·
verified ·
1 Parent(s): 5cf6d65

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -0
app.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from fastapi import FastAPI
4
+ from pydantic import BaseModel
5
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
6
+ import uvicorn
7
+
8
# Request schema: pydantic validates the incoming JSON body against this model.
class Query(BaseModel):
    """Body of a generation request: one free-text field named ``text``."""

    text: str
11
+
12
# Initialize FastAPI app
app = FastAPI(title="Financial Chatbot API")

# Load the fine-tuned causal LM and its tokenizer by repository name.
# device_map="auto" delegates device placement to the library.
# NOTE(review): trust_remote_code=True allows code shipped in the model
# repository to run at load time; confirm this checkpoint actually needs it.
model_name = "Phoenix21/llama-3-2-3b-finetuned-finance"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# Fall back to EOS for padding only when the tokenizer defines no pad token,
# so a pad token saved with the checkpoint is not silently overwritten.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
24
+
25
# Create a text-generation pipeline around the loaded model and tokenizer.
# temperature and top_p are sampling parameters: they only influence the
# output when sampling is enabled, so do_sample=True is set explicitly
# instead of depending on whatever the checkpoint's generation config says.
chat_pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=256,
    do_sample=True,
    temperature=0.7,
    top_p=0.95,
)
34
+
35
# Define an endpoint for generating responses
@app.post("/generate")
def generate(query: Query):
    """Generate an answer for the question in ``query.text``.

    The question is wrapped in a "Question: ...\\nAnswer: " prompt and passed
    to the module-level ``chat_pipe``.

    Args:
        query: Validated request body carrying the user's question.

    Returns:
        A dict of the form ``{"response": <generated answer>}``.
    """
    prompt = f"Question: {query.text}\nAnswer: "
    # By default the pipeline returns prompt + continuation; ask for the
    # continuation only so the response holds the answer, not an echo of
    # the prompt template.
    outputs = chat_pipe(prompt, return_full_text=False)
    answer = outputs[0]["generated_text"].strip()
    return {"response": answer}
41
+
42
# Script entry point: serve the app with uvicorn on all interfaces, taking
# the port from the PORT environment variable (8000 when it is unset).
if __name__ == "__main__":
    port = int(os.getenv("PORT", "8000"))
    uvicorn.run(app, host="0.0.0.0", port=port)