Phoenix21 committed · verified
Commit 1ebaed9 · Parent: caa4cee

Update app.py

Files changed (1)
  1. app.py +42 -8
app.py CHANGED
@@ -1,12 +1,12 @@
 import os
 import torch
+import json
 from fastapi import FastAPI
 from pydantic import BaseModel
 from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
-from peft import PeftModel
+from peft import PeftModel, PeftConfig
 import uvicorn
-
-from huggingface_hub import login
+from huggingface_hub import login, hf_hub_download
 
 # Authenticate with Hugging Face Hub using the HF_TOKEN environment variable
 HF_TOKEN = os.getenv("HF_TOKEN")
@@ -29,14 +29,45 @@ base_model = AutoModelForCausalLM.from_pretrained(
     trust_remote_code=True
 )
 
-# Load the finetuned adapter using PEFT
-peft_model_id = "Phoenix21/llama-3-2-3b-finetuned-finance_checkpoint2"
-model = PeftModel.from_pretrained(base_model, peft_model_id)
-
 # Load the tokenizer from the base model
 tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
 tokenizer.pad_token = tokenizer.eos_token
 
+# Load the finetuned adapter using PEFT, with handling for eva_config
+peft_model_id = "Phoenix21/llama-3-2-3b-finetuned-finance_checkpoint2"
+
+try:
+    # Try direct loading first
+    model = PeftModel.from_pretrained(base_model, peft_model_id)
+except TypeError as e:
+    if "eva_config" in str(e):
+        print("Handling eva_config compatibility issue...")
+        # Download the adapter config so it can be patched manually
+        config_path = hf_hub_download(repo_id=peft_model_id, filename="adapter_config.json")
+
+        with open(config_path, "r") as f:
+            config_dict = json.load(f)
+
+        # Remove the parameter this peft version does not recognize
+        if "eva_config" in config_dict:
+            del config_dict["eva_config"]
+
+        # Save the cleaned config where PeftConfig.from_pretrained can find it
+        fixed_dir = "fixed_adapter_config"
+        os.makedirs(fixed_dir, exist_ok=True)
+        with open(os.path.join(fixed_dir, "adapter_config.json"), "w") as f:
+            json.dump(config_dict, f)
+
+        # PeftConfig.from_json_file returns a plain dict, so rebuild a real
+        # config object by pointing from_pretrained at the cleaned directory
+        config = PeftConfig.from_pretrained(fixed_dir)
+
+        # Now load the adapter weights with the cleaned config
+        model = PeftModel.from_pretrained(base_model, peft_model_id, config=config)
+    else:
+        # If it's a different error, raise it
+        raise
+
 # Create a text-generation pipeline using the loaded model and tokenizer
 chat_pipe = pipeline(
     "text-generation",
@@ -51,8 +82,11 @@ chat_pipe = pipeline(
 def generate(query: Query):
     prompt = f"Question: {query.text}\nAnswer: "
     response = chat_pipe(prompt)[0]["generated_text"]
+    # Extract just the answer part from the response
+    if "Answer: " in response:
+        response = response.split("Answer: ", 1)[1]
     return {"response": response}
 
 if __name__ == "__main__":
     port = int(os.environ.get("PORT", 7860))
-    uvicorn.run(app, host="0.0.0.0", port=port)
+    uvicorn.run(app, host="0.0.0.0", port=port)
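The eva_config branch above targets one field by name. Any field written by a newer peft release than the one installed would trip the same TypeError, so the workaround generalizes. A minimal sketch of that idea follows; the helper name, the retry loop, and the parsing of CPython's standard "unexpected keyword argument" wording are illustrative assumptions, not part of this commit.

import json
import os
import re

from huggingface_hub import hf_hub_download
from peft import PeftConfig, PeftModel


def load_adapter_dropping_unknown_keys(base_model, adapter_id, max_retries=5):
    """Load a PEFT adapter, dropping config keys the installed peft rejects."""
    try:
        # Happy path: the installed peft version understands the saved config
        return PeftModel.from_pretrained(base_model, adapter_id)
    except TypeError:
        pass

    # Fetch the adapter config and retry, removing one unknown key per attempt
    config_path = hf_hub_download(repo_id=adapter_id, filename="adapter_config.json")
    with open(config_path) as f:
        config_dict = json.load(f)

    fixed_dir = "patched_adapter_config"
    os.makedirs(fixed_dir, exist_ok=True)
    for _ in range(max_retries):
        with open(os.path.join(fixed_dir, "adapter_config.json"), "w") as f:
            json.dump(config_dict, f)
        try:
            config = PeftConfig.from_pretrained(fixed_dir)
            return PeftModel.from_pretrained(base_model, adapter_id, config=config)
        except TypeError as e:
            # CPython phrases this as: __init__() got an unexpected keyword argument 'x'
            match = re.search(r"unexpected keyword argument '(\w+)'", str(e))
            if match is None or match.group(1) not in config_dict:
                raise
            del config_dict[match.group(1)]
    raise RuntimeError(f"Could not load adapter {adapter_id} after {max_retries} retries")

Pinning peft in the Space's requirements to the version used at finetuning time would avoid the incompatibility altogether.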
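With the Space running, the revised endpoint can be exercised end to end. The route decorator for generate() sits in lines this diff does not touch, so the POST /generate path below is an assumption inferred from the handler name and the Query model; only the text request field, the response key, and the 7860 default port are confirmed by the diff.

import requests

# Assumed route: the @app.post(...) decorator is outside this diff, so the
# "/generate" path and the POST method are inferred rather than confirmed.
resp = requests.post(
    "http://localhost:7860/generate",
    json={"text": "What is dollar-cost averaging?"},
    timeout=120,
)

# With the new split on "Answer: ", the returned text should no longer
# echo the "Question: ... Answer: " prompt scaffold.
print(resp.json()["response"])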