Phoenix21 committed
Commit bc3abd0 · verified · 1 Parent(s): 015495a

Update app.py

Files changed (1):
  1. app.py +12 -46
app.py CHANGED

@@ -5,10 +5,10 @@ from pydantic import BaseModel
 from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
 from peft import PeftModel
 import uvicorn
-import json
-from huggingface_hub import hf_hub_download, login
 
-# Authenticate with Hugging Face Hub
+from huggingface_hub import login
+
+# Authenticate with Hugging Face Hub using the HF_TOKEN environment variable
 HF_TOKEN = os.getenv("HF_TOKEN")
 if HF_TOKEN:
     login(token=HF_TOKEN)
@@ -21,55 +21,23 @@ class Query(BaseModel):
 
 app = FastAPI(title="Financial Chatbot API")
 
-# Load the base model
-base_model_name = "meta-llama/Llama-3.2-3B"  # Update if using a different base model
-model = AutoModelForCausalLM.from_pretrained(
+# Load the base model from Meta-Llama
+base_model_name = "meta-llama/Llama-3.2-3B"
+base_model = AutoModelForCausalLM.from_pretrained(
     base_model_name,
     device_map="auto",
     trust_remote_code=True
 )
 
-# Load adapter from your checkpoint with a fix for the 'eva_config' issue
+# Load the finetuned adapter using PEFT
 peft_model_id = "Phoenix21/llama-3-2-3b-finetuned-finance_checkpoint2"
+model = PeftModel.from_pretrained(base_model, peft_model_id)
 
-# Manually download and load the adapter config to filter out problematic fields
-try:
-    # Download the adapter_config.json file
-    config_file = hf_hub_download(
-        repo_id=peft_model_id,
-        filename="adapter_config.json",
-        token=HF_TOKEN
-    )
-
-    # Load and clean the config
-    with open(config_file, 'r') as f:
-        config_dict = json.load(f)
-
-    # Remove problematic fields if they exist
-    if "eva_config" in config_dict:
-        del config_dict["eva_config"]
-
-    # Load the adapter directly with the cleaned config
-    model = PeftModel.from_pretrained(
-        model,
-        peft_model_id,
-        config=config_dict
-    )
-except Exception as e:
-    print(f"Error loading adapter: {e}")
-    # Fallback to direct loading if the above fails
-    model = PeftModel.from_pretrained(
-        model,
-        peft_model_id,
-        # Use this config parameter to ignore unknown parameters
-        config=None
-    )
-
-# Load tokenizer from the base model
+# Load the tokenizer from the base model
 tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
 tokenizer.pad_token = tokenizer.eos_token
 
-# Create a text-generation pipeline
+# Create a text-generation pipeline using the loaded model and tokenizer
 chat_pipe = pipeline(
     "text-generation",
     model=model,
@@ -83,10 +51,8 @@ chat_pipe = pipeline(
 def generate(query: Query):
     prompt = f"Question: {query.text}\nAnswer: "
     response = chat_pipe(prompt)[0]["generated_text"]
-    # Extract only the answer part from the response
-    answer = response.split("Answer: ")[-1].strip()
-    return {"response": answer}
+    return {"response": response}
 
 if __name__ == "__main__":
     port = int(os.environ.get("PORT", 7860))
-    uvicorn.run(app, host="0.0.0.0", port=port)
+    uvicorn.run(app, host="0.0.0.0", port=port)
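Worth noting for anyone reusing this endpoint: with the string-splitting removed, generated_text from a transformers text-generation pipeline includes the prompt itself by default, so the API now echoes "Question: ... Answer: " back to the caller. The pipeline's standard return_full_text=False option drops the echoed prompt without any post-hoc splitting. A minimal sketch of the handler with that flag — the @app.post("/generate") route is an assumption here, since the actual decorator sits outside the diff hunks shown above:

@app.post("/generate")  # assumed route; the real decorator is outside the diff context
def generate(query: Query):
    prompt = f"Question: {query.text}\nAnswer: "
    # return_full_text=False tells the transformers pipeline to return only
    # the newly generated tokens, not the echoed prompt.
    response = chat_pipe(prompt, return_full_text=False)[0]["generated_text"]
    return {"response": response.strip()}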
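And a minimal client sketch for smoke-testing the running app. The "text" field matches the Query model and 7860 is the default port from app.py; the URL path reuses the same assumed /generate route:

import requests

# Assumes the app is running locally on the default port 7860 and that the
# endpoint is exposed as POST /generate (see the assumption above).
resp = requests.post(
    "http://localhost:7860/generate",
    json={"text": "What is dollar-cost averaging?"},
    timeout=120,  # generation with a 3B model can take a while, especially on CPU
)
resp.raise_for_status()
print(resp.json()["response"])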