homer-meng commited on
Commit
7c373db
·
1 Parent(s): 54d6e13

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -11
app.py CHANGED
@@ -1,21 +1,17 @@
1
- from typing import Optional
2
- import uvicorn
3
- from fastapi import FastAPI, Form, File, UploadFile
4
  from transformers import AutoTokenizer, GPTNeoForCausalLM
5
  from PIL import Image
6
  import io
7
 
8
- app = FastAPI()
9
-
10
  tokenizer = AutoTokenizer.from_pretrained("EleutherAI/sd-1.5")
11
  model = GPTNeoForCausalLM.from_pretrained("EleutherAI/sd-1.5")
12
 
13
- @app.post("/generate_drawing/")
14
- async def generate_drawing(prompt: str = Form(...)):
 
 
15
  inputs = tokenizer(prompt, return_tensors="pt")
16
  outputs = model.generate(inputs['input_ids'], max_length=256, do_sample=True)
17
  image = Image.open(io.BytesIO(outputs[0].cpu().numpy()))
18
- return {"image": image}
19
-
20
- if __name__ == "__main__":
21
- uvicorn.run(app, host="0.0.0.0", port=8000)
 
1
+ import streamlit as st
2
+ import torch
 
3
  from transformers import AutoTokenizer, GPTNeoForCausalLM
4
  from PIL import Image
5
  import io
6
 
 
 
7
  tokenizer = AutoTokenizer.from_pretrained("EleutherAI/sd-1.5")
8
  model = GPTNeoForCausalLM.from_pretrained("EleutherAI/sd-1.5")
9
 
10
+ st.title("Scribble Drawing Generator")
11
+
12
+ prompt = st.text_input("Enter a prompt:")
13
+ if prompt:
14
  inputs = tokenizer(prompt, return_tensors="pt")
15
  outputs = model.generate(inputs['input_ids'], max_length=256, do_sample=True)
16
  image = Image.open(io.BytesIO(outputs[0].cpu().numpy()))
17
+ st.image(image, caption="Generated Drawing")