Spaces:
Runtime error
Runtime error
llinahosna
commited on
Commit
•
34dc018
1
Parent(s):
36cabf5
Create dall_e.py
Browse files
dall_e.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, getpass
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import replicate
|
5 |
+
|
6 |
+
|
7 |
+
class DalleImageGenerator:
|
8 |
+
"""Dall-e model using Replicate API"""
|
9 |
+
def __init__(self, token=None):
|
10 |
+
if "REPLICATE_API_TOKEN" not in os.environ:
|
11 |
+
if token is not None:
|
12 |
+
os.environ["REPLICATE_API_TOKEN"] = token
|
13 |
+
else:
|
14 |
+
print(f"Please go to https://replicate.com/docs/api for your Replicate API token.")
|
15 |
+
os.environ["REPLICATE_API_TOKEN"] = getpass.getpass(f"Input Replicate API Token:")
|
16 |
+
|
17 |
+
self.dalle = replicate.models.get("kuprel/min-dalle")
|
18 |
+
|
19 |
+
def generate_images(self, text, grid_size, text_adherence=2):
|
20 |
+
urls = self.dalle.predict(text=text, grid_size=grid_size, log2_supercondition_factor=text_adherence)
|
21 |
+
images = get_image(list(urls)[-1])
|
22 |
+
h, w = images.shape[:2]
|
23 |
+
h, w = h // grid_size, w // grid_size
|
24 |
+
return blockshaped(images, h, w)
|
25 |
+
|
26 |
+
|
27 |
+
def get_image(url):
|
28 |
+
"""download image from a url"""
|
29 |
+
from urllib.request import Request, urlopen
|
30 |
+
import io
|
31 |
+
import PIL.Image as Image
|
32 |
+
hdr = {
|
33 |
+
'User-Agent': 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.11 (KHTML, like Gecko) Chrome/23.0.1271.64 Safari/537.11',
|
34 |
+
'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8',
|
35 |
+
'Accept-Charset': 'ISO-8859-1,utf-8;q=0.7,*;q=0.3',
|
36 |
+
'Accept-Encoding': 'none',
|
37 |
+
'Accept-Language': 'en-US,en;q=0.8',
|
38 |
+
'Connection': 'keep-alive'}
|
39 |
+
|
40 |
+
# urllib.request.urlretrieve(url, f"local-filename.jpg")
|
41 |
+
req = Request(url, headers=hdr)
|
42 |
+
page = urlopen(req)
|
43 |
+
return np.array(Image.open(io.BytesIO(page.read())))
|
44 |
+
|
45 |
+
|
46 |
+
def blockshaped(arr, nrows, ncols):
|
47 |
+
"""
|
48 |
+
Return an array of shape (n, nrows, ncols) where
|
49 |
+
n * nrows * ncols = arr.size
|
50 |
+
|
51 |
+
If arr is a 2D array, the returned array should look like n subblocks with
|
52 |
+
each subblock preserving the "physical" layout of arr.
|
53 |
+
"""
|
54 |
+
h, w, c = arr.shape
|
55 |
+
assert h % nrows == 0, f"{h} rows is not evenly divisible by {nrows}"
|
56 |
+
assert w % ncols == 0, f"{w} cols is not evenly divisible by {ncols}"
|
57 |
+
return (arr.reshape(h//nrows, nrows, w//ncols, ncols, - 1)
|
58 |
+
.swapaxes(1,2)
|
59 |
+
.reshape(-1, nrows, ncols, 3))
|