BeveledCube committed on
Commit
e1b04d1
·
verified ·
1 Parent(s): 054e3fb

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +6 -15
main.py CHANGED
@@ -3,24 +3,14 @@ from fastapi.responses import FileResponse
3
  from pydantic import BaseModel
4
  from fastapi import FastAPI
5
 
6
- import os
7
 
8
- from transformers import GPT2LMHeadModel, GPT2Tokenizer, AutoModelForCausalLM, AutoTokenizer
9
- import torch
10
 
11
  app = FastAPI()
12
- name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
13
- customGen = False
14
-
15
- # microsoft/DialoGPT-small
16
- # microsoft/DialoGPT-medium
17
- # microsoft/DialoGPT-large
18
-
19
- # mistralai/Mixtral-8x7B-Instruct-v0.1
20
-
21
- # Load the Hugging Face GPT-2 model and tokenizer
22
- model = AutoModelForCausalLM.from_pretrained(name)
23
- tokenizer = AutoTokenizer.from_pretrained(name)
24
 
25
  class req(BaseModel):
26
  prompt: str
@@ -34,6 +24,7 @@ def read_root():
34
  def read_root(data: req):
35
  print("Prompt:", data.prompt)
36
  print("Length:", data.length)
 
37
 
38
  input_text = data.prompt
39
 
 
3
  from pydantic import BaseModel
4
  from fastapi import FastAPI
5
 
6
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
7
 
8
+ model_name = "facebook/blenderbot-400M-distill"
9
+ # facebook/blenderbot-400M-distill
10
 
11
  app = FastAPI()
12
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
13
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
 
 
 
 
 
 
 
 
 
 
14
 
15
  class req(BaseModel):
16
  prompt: str
 
24
  def read_root(data: req):
25
  print("Prompt:", data.prompt)
26
  print("Length:", data.length)
27
+ print("Generating")
28
 
29
  input_text = data.prompt
30