barathm2001 committed on
Commit
fc5590b
·
verified ·
1 Parent(s): 6bc0593

Upload 4 files

Browse files
Files changed (1) hide show
  1. app.py +6 -12
app.py CHANGED
@@ -1,9 +1,7 @@
1
  import logging
2
  from fastapi import FastAPI, HTTPException
3
- from transformers import AutoModelForCausalLM, pipeline
4
  from peft import PeftModel, PeftConfig
5
- from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
6
- from mistral_common.client import MistralChain
7
 
8
  # Set up logging
9
  logging.basicConfig(level=logging.INFO)
@@ -16,11 +14,10 @@ app = FastAPI()
16
  model = None
17
  tokenizer = None
18
  pipe = None
19
- mistral_chain = None
20
 
21
  @app.on_event("startup")
22
  async def load_model():
23
- global model, tokenizer, pipe, mistral_chain
24
 
25
  try:
26
  logger.info("Loading PEFT configuration...")
@@ -33,10 +30,7 @@ async def load_model():
33
  model = PeftModel.from_pretrained(base_model, "frankmorales2020/Mistral-7B-text-to-sql-flash-attention-2-dataeval")
34
 
35
  logger.info("Loading tokenizer...")
36
- tokenizer = MistralTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")
37
-
38
- logger.info("Creating MistralChain...")
39
- mistral_chain = MistralChain(model, tokenizer)
40
 
41
  logger.info("Creating pipeline...")
42
  pipe = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
@@ -55,12 +49,12 @@ def home():
55
 
56
  @app.get("/generate")
57
  async def generate(text: str):
58
- if not mistral_chain:
59
  raise HTTPException(status_code=503, detail="Model not loaded")
60
 
61
  try:
62
- output = mistral_chain.generate(text, max_tokens=100)
63
- return {"output": output}
64
  except Exception as e:
65
  logger.error(f"Error during text generation: {e}")
66
  raise HTTPException(status_code=500, detail=f"Error during text generation: {str(e)}")
 
1
  import logging
2
  from fastapi import FastAPI, HTTPException
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
4
  from peft import PeftModel, PeftConfig
 
 
5
 
6
  # Set up logging
7
  logging.basicConfig(level=logging.INFO)
 
14
  model = None
15
  tokenizer = None
16
  pipe = None
 
17
 
18
  @app.on_event("startup")
19
  async def load_model():
20
+ global model, tokenizer, pipe
21
 
22
  try:
23
  logger.info("Loading PEFT configuration...")
 
30
  model = PeftModel.from_pretrained(base_model, "frankmorales2020/Mistral-7B-text-to-sql-flash-attention-2-dataeval")
31
 
32
  logger.info("Loading tokenizer...")
33
+ tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")
 
 
 
34
 
35
  logger.info("Creating pipeline...")
36
  pipe = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
 
49
 
50
  @app.get("/generate")
51
  async def generate(text: str):
52
+ if not pipe:
53
  raise HTTPException(status_code=503, detail="Model not loaded")
54
 
55
  try:
56
+ output = pipe(text, max_length=100, num_return_sequences=1)
57
+ return {"output": output[0]['generated_text']}
58
  except Exception as e:
59
  logger.error(f"Error during text generation: {e}")
60
  raise HTTPException(status_code=500, detail=f"Error during text generation: {str(e)}")