Ashrafb committed on
Commit
00465c6
·
verified ·
1 Parent(s): 44ad5df

Rename app.py to main.py

Browse files
Files changed (1) hide show
  1. app.py → main.py +20 -55
app.py → main.py RENAMED
@@ -1162,68 +1162,33 @@ from threading import Thread
1162
  from transformers import TextIteratorStreamer
1163
  import hashlib
1164
  import os
 
 
 
 
 
 
 
1165
 
1166
  model_path = snapshot_download("vikhyatk/moondream1")
1167
  vision_encoder = VisionEncoder(model_path).to(DEVICE, dtype=DTYPE)
1168
  text_model = TextModel(model_path).to(DEVICE, dtype=DTYPE)
1169
 
1170
 
1171
def cached_vision_encoder(image):
    """Encode *image* with the vision encoder, memoizing the result on disk.

    The cache key is a SHA-256 over the raw pixel bytes plus the image's
    size and mode, so two images whose pixel buffers happen to be
    byte-identical but have different shapes/modes do not collide.

    Returns the image embedding tensor on DEVICE with dtype DTYPE.
    """
    # Include size/mode in the digest: PIL's tobytes() carries no shape
    # metadata, so e.g. a 100x200 and a 200x100 image could otherwise
    # share a hash and return the wrong cached embedding.
    hasher = hashlib.sha256(image.tobytes())
    hasher.update(f"{image.size}{image.mode}".encode())
    image_hash = hasher.hexdigest()

    # Check if `image_encoder_cache/{image_hash}.pt` exists, if so load and return it.
    # Otherwise, save the encoded image to `image_encoder_cache/{image_hash}.pt` and return it.
    cache_path = f"image_encoder_cache/{image_hash}.pt"
    if os.path.exists(cache_path):
        # map_location="cpu" keeps the load safe regardless of the device the
        # tensor was saved from; it is moved to the target device afterwards.
        return torch.load(cache_path, map_location="cpu").to(DEVICE, dtype=DTYPE)
    else:
        # Store the embedding as float16 on CPU to keep cache files small.
        image_vec = vision_encoder(image).to("cpu", dtype=torch.float16)
        os.makedirs("image_encoder_cache", exist_ok=True)
        torch.save(image_vec, cache_path)
        return image_vec.to(DEVICE, dtype=DTYPE)
1185
 
 
 
 
 
 
 
 
1186
 
1187
def answer_question(image, question):
    """Stream the model's answer to *question* about *image*.

    Yields progressively longer partial answers as tokens arrive, so the UI
    can render the response incrementally.  The final yield is always the
    complete, cleaned answer.
    """
    yield "Encoding image..."

    streamer = TextIteratorStreamer(text_model.tokenizer, skip_special_tokens=True)
    generation_kwargs = dict(
        image_embeds=cached_vision_encoder(image), question=question, streamer=streamer
    )
    # Run generation in a background thread; the streamer feeds tokens back
    # to this generator as they are produced.
    thread = Thread(target=text_model.answer_question, kwargs=generation_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        if len(buffer) > 1:
            # Strip a trailing end-of-answer marker ("END") and any partial
            # "<" of a special token still being streamed in.
            yield re.sub("<$", "", re.sub("END$", "", buffer))

    # Reap the generation thread, and always emit the final cleaned buffer:
    # the in-loop `len(buffer) > 1` guard would otherwise swallow answers of
    # length <= 1 entirely.
    thread.join()
    yield re.sub("<$", "", re.sub("END$", "", buffer))
1202
-
1203
-
1204
# Build and launch the Gradio demo UI: one image + one free-form question in,
# a streamed text answer out.  `answer_question` is a generator, so Gradio
# renders the answer incrementally as tokens arrive.
gr.Interface(
    title="🌔 moondream1",
    description="""
    moondream1 is a tiny (1.6B parameter) vision language model trained by
    <a href="https://x.com/vikhyatk">@vikhyatk</a> that performs on par with
    models twice its size. It is trained on the LLaVa training dataset, and
    initialized with SigLIP as the vision tower and Phi-1.5 as the text encoder.
    Check out the <a href="https://huggingface.co/vikhyatk/moondream1">HuggingFace
    model card</a> for more details.
    """,
    fn=answer_question,
    inputs=[gr.Image(type="pil"), gr.Textbox(lines=2, label="Question")],
    # Example images are opened eagerly at import time; cache_examples=False
    # below keeps the app from pre-running the model on them at startup.
    examples=[
        [Image.open("assets/demo-1.jpg"), "Who is the author of this book?"],
        [Image.open("assets/demo-2.jpg"), "What type of food is the girl eating?"],
        [
            Image.open("assets/demo-3.jpg"),
            "What kind of public transportation is in the image?",
        ],
        [Image.open("assets/demo-4.jpg"), "What is the girl looking at?"],
        [Image.open("assets/demo-5.jpg"), "What kind of dog is in the picture?"],
    ],
    outputs=gr.TextArea(label="Answer"),
    allow_flagging="never",
    cache_examples=False,
).launch()
 
1162
  from transformers import TextIteratorStreamer
1163
  import hashlib
1164
  import os
1165
+ from fastapi import FastAPI, File, UploadFile, Form
1166
+ from PIL import Image
1167
+ from io import BytesIO
1168
+ from typing import List
1169
+ from pydantic import BaseModel
1170
+ from fastapi.responses import HTMLResponse, FileResponse
1171
+ from fastapi.staticfiles import StaticFiles
1172
 
1173
# Download the model weights once (cached locally by huggingface_hub) and
# build the two halves of moondream1: the vision encoder and the text model.
# NOTE(review): DEVICE and DTYPE are presumably defined earlier in the file —
# they are not visible in this chunk.
model_path = snapshot_download("vikhyatk/moondream1")
vision_encoder = VisionEncoder(model_path).to(DEVICE, dtype=DTYPE)
text_model = TextModel(model_path).to(DEVICE, dtype=DTYPE)
1176
 
1177
 
 
 
 
1178
 
1179
# Define a FastAPI app serving the question-answering endpoint plus the
# static frontend.
app = FastAPI()


# Define route for answering questions.
@app.post("/upload/")
async def answer(image: UploadFile = File(...), Question: str = Form(...)):
    """Accept an uploaded image and a form-field question; return the answer.

    `Question` keeps its capitalized spelling because it is the public form
    field name clients already send.
    """
    image_bytes = await image.read()
    pil_image = Image.open(BytesIO(image_bytes))
    # NOTE(review): the Gradio-era answer_question was a *generator* yielding
    # progressively longer partial answers.  If the current implementation
    # still streams, returning it directly would put a non-serializable
    # generator in the JSON body — so drain it and keep the last (complete)
    # value.  For a plain string result this is a no-op.
    result = answer_question(pil_image, Question)
    if hasattr(result, "__iter__") and not isinstance(result, (str, bytes)):
        final = ""
        for partial in result:
            final = partial
        result = final
    return {"answer": result}


@app.get("/")
def index() -> FileResponse:
    # NOTE(review): absolute "/app/static/..." presumably matches the
    # container working directory of the deployment — confirm it agrees with
    # StaticFiles(directory="static") below.
    return FileResponse(path="/app/static/index.html", media_type="text/html")


# Mount the static files LAST: FastAPI matches in registration order, so a
# mount on "/" registered before @app.get("/") would shadow the index route.
app.mount("/", StaticFiles(directory="static", html=True), name="static")