Phoenix21 committed
Commit b6332c3 · verified · 1 Parent(s): b4957d9

Update app.py

Files changed (1): app.py (+41 -15)
app.py CHANGED
@@ -3,10 +3,10 @@ import torch
 from fastapi import FastAPI
 from pydantic import BaseModel
 from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
-from peft import PeftModel, PeftConfig
+from peft import PeftModel
 import uvicorn
-
-from huggingface_hub import login
+import json
+from huggingface_hub import hf_hub_download, login
 
 # Authenticate with Hugging Face Hub
 HF_TOKEN = os.getenv("HF_TOKEN")
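
The import swap above sets up the workaround in the next hunk: instead of letting peft parse adapter_config.json (and fail on fields it does not recognize), the file is downloaded and parsed by hand. A quick diagnostic, assuming the root cause is version skew (the checkpoint was saved by a newer peft that writes an eva_config field), is to check whether the installed LoraConfig knows the field:

# Diagnostic sketch; assumes a LoRA adapter and that the failure stems from
# peft version skew rather than a corrupt config file.
import dataclasses
from peft import LoraConfig

known_fields = {f.name for f in dataclasses.fields(LoraConfig)}
print("eva_config supported:", "eva_config" in known_fields)

If this prints False, upgrading peft would make the manual cleanup below unnecessary.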
@@ -29,17 +29,41 @@ model = AutoModelForCausalLM.from_pretrained(
     trust_remote_code=True
 )
 
-# Load adapter from your checkpoint with a workaround for 'eva_config'
+# Load adapter from your checkpoint with a fix for the 'eva_config' issue
 peft_model_id = "Phoenix21/llama-3-2-3b-finetuned-finance_checkpoint2"
-# Load the PEFT configuration first
-peft_config = PeftConfig.from_pretrained(peft_model_id)
-# Remove 'eva_config' if it exists in the configuration
-peft_config_dict = peft_config.to_dict()
-if "eva_config" in peft_config_dict:
-    peft_config_dict.pop("eva_config")
-peft_config = PeftConfig.from_dict(peft_config_dict)
-# Load the adapter using the filtered configuration
-model = PeftModel.from_pretrained(model, peft_model_id, config=peft_config)
+
+# Manually download and load the adapter config to filter out problematic fields
+try:
+    # Download the adapter_config.json file
+    config_file = hf_hub_download(
+        repo_id=peft_model_id,
+        filename="adapter_config.json",
+        token=HF_TOKEN
+    )
+
+    # Load and clean the config
+    with open(config_file, 'r') as f:
+        config_dict = json.load(f)
+
+    # Remove problematic fields if they exist
+    if "eva_config" in config_dict:
+        del config_dict["eva_config"]
+
+    # Load the adapter directly with the cleaned config
+    model = PeftModel.from_pretrained(
+        model,
+        peft_model_id,
+        config=config_dict
+    )
+except Exception as e:
+    print(f"Error loading adapter: {e}")
+    # Fallback to direct loading if the above fails
+    model = PeftModel.from_pretrained(
+        model,
+        peft_model_id,
+        # Use this config parameter to ignore unknown parameters
+        config=None
+    )
 
 # Load tokenizer from the base model
 tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
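
One caveat on the new loading path: current peft releases validate that the config argument of PeftModel.from_pretrained is a PeftConfig instance and raise otherwise, so passing the cleaned dict directly is likely to fall through to the except branch. A stricter variant, sketched below, rebuilds a real config object from the cleaned dict; it reuses config_dict, model, and peft_model_id from the hunk above and assumes a LoRA adapter:

# Sketch: build a proper LoraConfig from the cleaned dict before loading.
# Assumes a LoRA adapter; reuses config_dict/model/peft_model_id from above.
import dataclasses
from peft import LoraConfig, PeftModel

known_fields = {f.name for f in dataclasses.fields(LoraConfig) if f.init}
# Keep only keys the installed peft understands; this drops eva_config and
# any other field written by a newer peft release.
cleaned = {k: v for k, v in config_dict.items() if k in known_fields}
model = PeftModel.from_pretrained(model, peft_model_id, config=LoraConfig(**cleaned))

Filtering against dataclasses.fields covers any future unknown field, not just eva_config, so the hand-maintained blocklist disappears.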
@@ -59,8 +83,10 @@ chat_pipe = pipeline(
 def generate(query: Query):
     prompt = f"Question: {query.text}\nAnswer: "
     response = chat_pipe(prompt)[0]["generated_text"]
-    return {"response": response}
+    # Extract only the answer part from the response
+    answer = response.split("Answer: ")[-1].strip()
+    return {"response": answer}
 
 if __name__ == "__main__":
     port = int(os.environ.get("PORT", 7860))
-    uvicorn.run(app, host="0.0.0.0", port=port)
+    uvicorn.run(app, host="0.0.0.0", port=port)
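
For an end-to-end check of the reshaped response, a usage sketch follows; it assumes the handler is registered as a POST route at /generate (the decorator sits outside these hunks) and that the server listens on the default port 7860:

# Usage sketch; the /generate path is an assumption, since the route
# decorator does not appear in this diff.
import requests

resp = requests.post(
    "http://localhost:7860/generate",
    json={"text": "What is dollar-cost averaging?"},
)
print(resp.json()["response"])

As for the string split itself: the transformers text-generation pipeline also accepts return_full_text=False, which omits the prompt from generated_text and would make the split on "Answer: " unnecessary.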
 