Rename app.py to main.py
Browse files- app.py → main.py +20 -55
app.py → main.py
RENAMED
@@ -1162,68 +1162,33 @@ from threading import Thread
|
|
1162 |
from transformers import TextIteratorStreamer
|
1163 |
import hashlib
|
1164 |
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1165 |
|
1166 |
model_path = snapshot_download("vikhyatk/moondream1")
|
1167 |
vision_encoder = VisionEncoder(model_path).to(DEVICE, dtype=DTYPE)
|
1168 |
text_model = TextModel(model_path).to(DEVICE, dtype=DTYPE)
|
1169 |
|
1170 |
|
1171 |
-
def cached_vision_encoder(image):
|
1172 |
-
# Calculate checksum of the image
|
1173 |
-
image_hash = hashlib.sha256(image.tobytes()).hexdigest()
|
1174 |
|
1175 |
-
|
1176 |
-
|
1177 |
-
cache_path = f"image_encoder_cache/{image_hash}.pt"
|
1178 |
-
if os.path.exists(cache_path):
|
1179 |
-
return torch.load(cache_path).to(DEVICE, dtype=DTYPE)
|
1180 |
-
else:
|
1181 |
-
image_vec = vision_encoder(image).to("cpu", dtype=torch.float16)
|
1182 |
-
os.makedirs("image_encoder_cache", exist_ok=True)
|
1183 |
-
torch.save(image_vec, cache_path)
|
1184 |
-
return image_vec.to(DEVICE, dtype=DTYPE)
|
1185 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1186 |
|
1187 |
-
|
1188 |
-
yield "Encoding image..."
|
1189 |
|
1190 |
-
|
1191 |
-
|
1192 |
-
|
1193 |
-
)
|
1194 |
-
thread = Thread(target=text_model.answer_question, kwargs=generation_kwargs)
|
1195 |
-
thread.start()
|
1196 |
-
|
1197 |
-
buffer = ""
|
1198 |
-
for new_text in streamer:
|
1199 |
-
buffer += new_text
|
1200 |
-
if len(buffer) > 1:
|
1201 |
-
yield re.sub("<$", "", re.sub("END$", "", buffer))
|
1202 |
-
|
1203 |
-
|
1204 |
-
gr.Interface(
|
1205 |
-
title="🌔 moondream1",
|
1206 |
-
description="""
|
1207 |
-
moondream1 is a tiny (1.6B parameter) vision language model trained by
|
1208 |
-
<a href="https://x.com/vikhyatk">@vikhyatk</a> that performs on par with
|
1209 |
-
models twice its size. It is trained on the LLaVa training dataset, and
|
1210 |
-
initialized with SigLIP as the vision tower and Phi-1.5 as the text encoder.
|
1211 |
-
Check out the <a href="https://huggingface.co/vikhyatk/moondream1">HuggingFace
|
1212 |
-
model card</a> for more details.
|
1213 |
-
""",
|
1214 |
-
fn=answer_question,
|
1215 |
-
inputs=[gr.Image(type="pil"), gr.Textbox(lines=2, label="Question")],
|
1216 |
-
examples=[
|
1217 |
-
[Image.open("assets/demo-1.jpg"), "Who is the author of this book?"],
|
1218 |
-
[Image.open("assets/demo-2.jpg"), "What type of food is the girl eating?"],
|
1219 |
-
[
|
1220 |
-
Image.open("assets/demo-3.jpg"),
|
1221 |
-
"What kind of public transportation is in the image?",
|
1222 |
-
],
|
1223 |
-
[Image.open("assets/demo-4.jpg"), "What is the girl looking at?"],
|
1224 |
-
[Image.open("assets/demo-5.jpg"), "What kind of dog is in the picture?"],
|
1225 |
-
],
|
1226 |
-
outputs=gr.TextArea(label="Answer"),
|
1227 |
-
allow_flagging="never",
|
1228 |
-
cache_examples=False,
|
1229 |
-
).launch()
|
|
|
1162 |
from transformers import TextIteratorStreamer
|
1163 |
import hashlib
|
1164 |
import os
|
1165 |
+
from fastapi import FastAPI, File, UploadFile, Form
|
1166 |
+
from PIL import Image
|
1167 |
+
from io import BytesIO
|
1168 |
+
from typing import List
|
1169 |
+
from pydantic import BaseModel
|
1170 |
+
from fastapi.responses import HTMLResponse, FileResponse
|
1171 |
+
from fastapi.staticfiles import StaticFiles
|
1172 |
|
1173 |
model_path = snapshot_download("vikhyatk/moondream1")
|
1174 |
vision_encoder = VisionEncoder(model_path).to(DEVICE, dtype=DTYPE)
|
1175 |
text_model = TextModel(model_path).to(DEVICE, dtype=DTYPE)
|
1176 |
|
1177 |
|
|
|
|
|
|
|
1178 |
|
1179 |
+
# Define a FastAPI app
|
1180 |
+
app = FastAPI()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1181 |
|
1182 |
+
# Define route for answering questions
|
1183 |
+
@app.post("/upload/")
|
1184 |
+
async def answer(image: UploadFile = File(...), Question: str = Form(...)):
|
1185 |
+
image_bytes = await image.read()
|
1186 |
+
image = Image.open(BytesIO(image_bytes))
|
1187 |
+
answer = answer_question(image, Question)
|
1188 |
+
return {"answer": answer}
|
1189 |
|
1190 |
+
app.mount("/", StaticFiles(directory="static", html=True), name="static")
|
|
|
1191 |
|
1192 |
+
@app.get("/")
|
1193 |
+
def index() -> FileResponse:
|
1194 |
+
return FileResponse(path="/app/static/index.html", media_type="text/html")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|