streetyogi commited on
Commit
62a4b51
·
1 Parent(s): d934794

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +34 -3
main.py CHANGED
@@ -1,12 +1,43 @@
1
  from fastapi import FastAPI
2
  from fastapi.staticfiles import StaticFiles
3
  from fastapi.responses import FileResponse
4
-
5
- from transformers import pipeline
6
 
7
  app = FastAPI()
8
 
9
- pipe_flan = pipeline("text2text-generation", model="google/flan-t5-small")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  @app.get("/infer_t5")
12
  def t5(input):
 
1
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
# BUG FIX: `pipeline` is used at the bottom of this script but its import
# was dropped; also `T5ForCausalLM` does not exist in transformers — T5 is
# an encoder-decoder model and is loaded via T5ForConditionalGeneration.
from transformers import (
    T5Tokenizer,
    T5ForConditionalGeneration,
    Trainer,
    TrainingArguments,
    pipeline,
)
import torch

app = FastAPI()

# Initialize the tokenizer and model.
tokenizer = T5Tokenizer.from_pretrained("t5-base")
model = T5ForConditionalGeneration.from_pretrained("t5-base")

# Load the fine-tuning corpus: one training example per non-empty line.
# NOTE(review): assumes cyberpunk_lore.txt holds line-delimited examples —
# confirm against the actual file format.
with open("cyberpunk_lore.txt", "r") as f:
    examples = [line.strip() for line in f if line.strip()]

# Tokenize the dataset.
# BUG FIX: batch_encode_plus on one raw string treats each CHARACTER as a
# separate sequence; encode a list of example strings instead, padded and
# truncated so the batch tensors are rectangular.
encodings = tokenizer(
    examples, return_tensors="pt", padding=True, truncation=True
)


class _LMDataset(torch.utils.data.Dataset):
    """Minimal dataset yielding per-example dicts, as Trainer expects.

    BUG FIX: Trainer cannot consume a bare input_ids tensor; each item
    must be a mapping of model inputs, and seq2seq training requires a
    `labels` entry (here: the input ids themselves, i.e. reconstruction).
    """

    def __init__(self, enc):
        self.enc = enc

    def __len__(self):
        return self.enc["input_ids"].size(0)

    def __getitem__(self, idx):
        item = {key: tensor[idx] for key, tensor in self.enc.items()}
        item["labels"] = item["input_ids"].clone()
        return item


train_dataset = _LMDataset(encodings)

# Set up training arguments.
training_args = TrainingArguments(
    output_dir="./results",
    overwrite_output_dir=True,
    num_train_epochs=5,
    per_device_train_batch_size=1,
    save_steps=10_000,
    save_total_limit=2,
)

# Create a Trainer.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=train_dataset,
)

# Fine-tune the model (runs at import time, as in the original script).
trainer.train()

# Create the inference pipeline.
# BUG FIX: when a model *object* (not a hub id) is passed, the tokenizer
# must be supplied explicitly or pipeline() cannot resolve one.
pipe_flan = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
41
 
42
  @app.get("/infer_t5")
43
  def t5(input):