Abhisesh7 commited on
Commit
26dda45
·
verified ·
1 Parent(s): 31d47f7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -7
app.py CHANGED
@@ -1,5 +1,5 @@
1
  import torch
2
- from transformers import DallEProcessor, DallEModel
3
  from PIL import Image
4
  import requests
5
  from flask import Flask, request, jsonify
@@ -8,9 +8,9 @@ import io
8
  app = Flask(__name__)
9
 
10
  # Initialize the DALL-E mini model and processor
11
- model_name = "dalle-mini/dalle-mini"
12
- processor = DallEProcessor.from_pretrained(model_name)
13
- model = DallEModel.from_pretrained(model_name)
14
 
15
  @app.route('/generate', methods=['POST'])
16
  def generate_image():
@@ -19,11 +19,11 @@ def generate_image():
19
 
20
  # Generate images from prompt
21
  inputs = processor(text=[prompt], return_tensors="pt")
22
- outputs = model(**inputs)
23
 
24
  # Post-process the generated image
25
- generated_image = outputs.logits.argmax(-1)[0]
26
- image = Image.fromarray(generated_image)
27
 
28
  # Save image to a BytesIO object
29
  img_byte_arr = io.BytesIO()
 
1
  import torch
2
+ from transformers import DALLMiniProcessor, DALLMiniModel
3
  from PIL import Image
4
  import requests
5
  from flask import Flask, request, jsonify
 
8
  app = Flask(__name__)
9
 
10
  # Initialize the DALL-E mini model and processor
11
+ model_name = "flax-community/dalle-mini"
12
+ processor = DALLMiniProcessor.from_pretrained(model_name)
13
+ model = DALLMiniModel.from_pretrained(model_name)
14
 
15
  @app.route('/generate', methods=['POST'])
16
  def generate_image():
 
19
 
20
  # Generate images from prompt
21
  inputs = processor(text=[prompt], return_tensors="pt")
22
+ outputs = model.generate(**inputs)
23
 
24
  # Post-process the generated image
25
+ generated_image = outputs[0]
26
+ image = Image.fromarray(generated_image.numpy().astype('uint8'))
27
 
28
  # Save image to a BytesIO object
29
  img_byte_arr = io.BytesIO()