DR-Rakshitha committed on
Commit
c12b438
·
1 Parent(s): 8272482

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -12
app.py CHANGED
@@ -18,23 +18,23 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
18
  # pipe = pipeline(task="text-generation", model=model, tokenizer=tokenizer, max_length=200)
19
  # result = pipe(f"<s>[INST] {input_text} [/INST]")
20
  # return result[0]['generated_text']
21
-
22
  from transformers import AutoModelForCausalLM, AutoTokenizer
23
- from fastapi import FastAPI
24
 
25
- app = FastAPI()
 
 
26
 
27
- model_name = "pytorch_model-00001-of-00002.bin" # Replace with your Hugging Face model name
 
 
28
 
29
- model = AutoModelForCausalLM.from_pretrained(model_name)
30
- tokenizer = AutoTokenizer.from_pretrained(model_name)
 
 
 
31
 
32
- @app.post("/generate/")
33
- async def generate_text(prompt: str):
34
- input_ids = tokenizer(prompt, return_tensors="pt").input_ids
35
- output = model.generate(input_ids, max_length=50, num_return_sequences=1)
36
- generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
37
- return {"generated_text": generated_text}
38
 
39
 
40
  text_generation_interface = gr.Interface(
 
18
  # pipe = pipeline(task="text-generation", model=model, tokenizer=tokenizer, max_length=200)
19
  # result = pipe(f"<s>[INST] {input_text} [/INST]")
20
  # return result[0]['generated_text']
 
21
  from transformers import AutoModelForCausalLM, AutoTokenizer
 
22
 
23
# Specify the path to your fine-tuned model and tokenizer.
model_path = "./"  # assumes the model files sit in the same directory as this script — TODO confirm

# NOTE(review): this names a single weights *shard* file, not a model id, and it
# is never used below — from_pretrained() discovers all shards via model_path.
# Kept only so any later code referencing `model_name` keeps working.
model_name = "pytorch_model-00001-of-00002.bin"

# Load the model and tokenizer from the local directory.
model = AutoModelForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Example usage: generate up to 50 tokens from a fixed prompt.
input_text = "Once upon a time"
input_ids = tokenizer(input_text, return_tensors="pt").input_ids
output = model.generate(input_ids, max_length=50, num_return_sequences=1)
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)

print(generated_text)
 
 
 
 
 
38
 
39
 
40
  text_generation_interface = gr.Interface(