Empereur-Pirate committed
Commit 03ee1c6 · verified · 1 Parent(s): 79e06e3

Update main.py

Files changed (1)
  main.py +48 -44
main.py CHANGED
@@ -5,63 +5,67 @@ from transformers import pipeline
  from pydantic import BaseModel
  from typing import Optional, Any
  import torch
- from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig

- app = FastAPI()

- # Initialize device
- def initialize_device():
-     global device
-     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

- initialize_device()

  # Helper function to read raw request bodies
  async def parse_raw(request: Request):
      return await request.body()

- # Initialize the model and tokenizer with the corrected pre-trained weights
- def init_corrected_model():
-     global model_config, model, tokenizer

-     try:
-         model_config = AutoConfig.from_pretrained("152334H/miqu-1-70b-sf", trust_remote_code=True)
-         model = AutoModelForCausalLM.from_pretrained("152334H/miqu-1-70b-sf", config=model_config).to(device)
-         tokenizer = AutoTokenizer.from_pretrained("152334H/miqu-1-70b-sf")
-     except Exception as e:
-         print("[WARNING]: Failed to load model and tokenizer conventionally.")
-         print(f"Exception: {e}")

-         model_config = AutoConfig.from_pretrained("152334H/miqu-1-70b-sf", trust_remote_code=True)
-         model = AutoModelForCausalLM.from_pretrained("152334H/miqu-1-70b-sf", config=model_config).to(device)
-         tokenizer = AutoTokenizer.from_pretrained("152334H/miqu-1-70b-sf")

- init_corrected_model()
-
- # Utility function to generate answers from the model
- def miuk_answer(query: str) -> dict[str, str]:
-     query_tokens = tokenizer.encode(query, return_tensors="pt")
-     query_tokens = query_tokens.to(device)
-     answer = model.generate(query_tokens, max_length=128, temperature=1, pad_token_id=tokenizer.pad_token_id)
-     return {"output": tokenizer.decode(answer[:, 0])}
-
- # Endpoint handler to receive incoming queries and pass them to the utility function for processing
- @app.post("/infer_miku")
- async def infer_endpoint(data: BaseModel = Depends(parse_raw)):
      input_text = data.raw.decode("utf-8")
-
-     if input_text is None or len(input_text) == 0:
          return JSONResponse({"error": "Empty input received."}, status_code=400)

-     result = miuk_answer(input_text)
-     return result
-
- @app.get("/infer_miku")
- def get_default_inference_endpoint():
-     return {"message": "Use POST method to submit input data"}

  # Mount static files
- app.mount("/static", StaticFiles(directory="static"), name="static")
-
- # Initialization done
- print("Initialization Complete.")
 
  from pydantic import BaseModel
  from typing import Optional, Any
  import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextStreamer, GenerationConfig

+ # Authentication
+ from huggingface_hub import interpreter_login
+ interpreter_login()
+
+ # Packages and model loading
+ import torch
+ base_model_id = "152334H/miqu-1-70b-sf"
+ bnb_config = BitsAndBytesConfig(
+     load_in_4bit=True,
+     bnb_4bit_use_double_quant=True,
+     bnb_4bit_quant_type="nf4",
+     bnb_4bit_compute_dtype=torch.bfloat16
+ )

+ base_model = AutoModelForCausalLM.from_pretrained(
+     base_model_id,
+     quantization_config=bnb_config,
+     device_map="auto",
+     trust_remote_code=True,
+ )

+ # Tokenizer loading
+ eval_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-70b-hf", add_bos_token=True, trust_remote_code=True, use_auth_token=True)
+
+ # Streamer
+ streamer = TextStreamer(eval_tokenizer)
+
+ # App definition
+ app = FastAPI()

  # Helper function to read raw request bodies
  async def parse_raw(request: Request):
      return await request.body()

+ # Generate text
+ def generate_text(prompt: str) -> str:
+     model_input = eval_tokenizer(prompt, return_tensors="pt").to("cuda")

+     base_model.eval()
+     with torch.no_grad():
+         generated_sequences = base_model.generate(
+             **model_input,
+             max_new_tokens=4096,
+             repetition_penalty=1.1,
+             do_sample=True,
+             temperature=1,
+             streamer=streamer,
+         )

+     return eval_tokenizer.decode(generated_sequences[0], skip_special_tokens=True)

+ # Route for generating text
+ @app.post("/generate_text")
+ async def generate_text_route(data: BaseModel = Depends(parse_raw)):
      input_text = data.raw.decode("utf-8")
+     if not input_text or len(input_text) <= 0:
          return JSONResponse({"error": "Empty input received."}, status_code=400)

+     return {"output": generate_text(input_text)}

  # Mount static files
+ app.mount("/static", StaticFiles(directory="static"), name="static")
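
Note that the new POST handler keeps the old request-parsing pattern: parse_raw returns the raw bytes of the request body, so data.raw.decode("utf-8") would raise AttributeError on the first request. Below is a minimal sketch, not part of this commit, of a handler that reads the body directly from the Request object; it assumes the generate_text helper above and the FastAPI, Request, and JSONResponse imports already present in main.py.

# Sketch only: read the request body directly instead of going through Depends(parse_raw).
@app.post("/generate_text")
async def generate_text_route(request: Request):
    # request.body() returns the raw bytes of the POST payload
    input_text = (await request.body()).decode("utf-8")

    if not input_text:
        return JSONResponse({"error": "Empty input received."}, status_code=400)

    return {"output": generate_text(input_text)}

With the app running (for example via uvicorn main:app --host 0.0.0.0 --port 8000), the route can then be exercised with a plain-text POST such as: curl -X POST --data "Hello" http://localhost:8000/generate_text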