Charan5775 committed
Commit 914bf49 · verified · 1 Parent(s): 731bac6

Update app.py

Files changed (1): app.py (+100 -16)
app.py CHANGED
@@ -1,22 +1,53 @@
-from fastapi import FastAPI, HTTPException
+from fastapi import FastAPI, HTTPException, UploadFile, File, Form, Depends
 from typing import Optional
 from fastapi.responses import StreamingResponse
 from huggingface_hub import InferenceClient
-from pydantic import BaseModel
+from pydantic import BaseModel, ConfigDict
 import os
-import uvicorn
-
+from base64 import b64encode
+from io import BytesIO
+from PIL import Image  # Add this import
+import logging
 
 app = FastAPI()
 
+# Configure logging
+logging.basicConfig(level=logging.DEBUG)
+logger = logging.getLogger(__name__)
+
+# Get HuggingFace token from environment variable
 
 # Default model
 DEFAULT_MODEL = "meta-llama/Meta-Llama-3-8B-Instruct"
 
 class QueryRequest(BaseModel):
+    model_config = ConfigDict(protected_namespaces=())
+
     query: str
+    image_data: Optional[str] = None  # Base64 encoded image data
     stream: bool = False
-    model_name: Optional[str] = None  # If not provided, will use DEFAULT_MODEL
+    model_name: Optional[str] = None
+
+class ChatForm(BaseModel):
+    model_config = ConfigDict(protected_namespaces=())
+
+    query: str
+    stream: bool = False
+    model_name: Optional[str] = None
+
+    @classmethod
+    def as_form(
+        cls,
+        query: str = Form(...),
+        stream: bool = Form(False),
+        model_name: Optional[str] = Form(None),
+        image: Optional[UploadFile] = File(None)
+    ):
+        return cls(
+            query=query,
+            stream=stream,
+            model_name=model_name
+        ), image
 
 def get_client(model_name: Optional[str] = None):
     """Get inference client for specified model or default model"""
@@ -25,7 +56,7 @@ def get_client(model_name: Optional[str] = None):
         model_path = model_name if model_name and model_name.strip() else DEFAULT_MODEL
 
         return InferenceClient(
-            model_path
+            model=model_path
         )
     except Exception as e:
         raise HTTPException(
@@ -33,12 +64,26 @@ def get_client(model_name: Optional[str] = None):
             detail=f"Error initializing model {model_path}: {str(e)}"
         )
 
-def generate_response(query: str, model_name: Optional[str] = None):
+def generate_response(query: str, image_data: Optional[str] = None, model_name: Optional[str] = None):
     messages = []
-    messages.append({
-        "role": "user",
-        "content": f"[SYSTEM] You are ASSISTANT who answer question asked by user in short and concise manner. [USER] {query}"
-    })
+
+    # Create the system and user message
+    user_content = f"[SYSTEM] You are ASSISTANT who answer question asked by user in short and concise manner. [USER] {query}"
+
+    # If there's an image, add it to the message
+    if image_data:
+        messages.append({
+            "role": "user",
+            "content": [
+                {"type": "text", "text": user_content},
+                {"type": "image_url", "image_url": {"url": f"data:image/*;base64,{image_data}"}}
+            ]
+        })
+    else:
+        messages.append({
+            "role": "user",
+            "content": user_content
+        })
 
     try:
         client = get_client(model_name)
@@ -57,17 +102,56 @@ async def root():
     return {"message": "Welcome to FastAPI server!"}
 
 @app.post("/chat")
-async def chat(request: QueryRequest):
+async def chat(form_data: tuple[ChatForm, Optional[UploadFile]] = Depends(ChatForm.as_form)):
+    form, image = form_data
     try:
-        if request.stream:
+        image_data = None
+        if image:
+            logger.debug("Image received")
+            # Read the image
+            contents = await image.read()
+
+            # Convert image to appropriate format if needed
+            try:
+                logger.debug("Attempting to open image")
+                img = Image.open(BytesIO(contents))
+                logger.debug(f"Image format before conversion: {img.format}, mode: {img.mode}")
+                # Convert to RGB if needed
+                if img.mode != 'RGB':
+                    img = img.convert('RGB')
+                logger.debug(f"Image format after conversion: {img.format}, mode: {img.mode}")
+
+                # Save as JPEG in memory
+                buffer = BytesIO()
+                img.save(buffer, format="JPEG")
+                image_data = b64encode(buffer.getvalue()).decode('utf-8')
+                logger.debug("Image processed and encoded to base64")
+            except Exception as img_error:
+                logger.error(f"Error processing image: {str(img_error)}")
+                raise HTTPException(
+                    status_code=422,
+                    detail=f"Error processing image: {str(img_error)}"
+                )
+
+        if form.stream:
             return StreamingResponse(
-                generate_response(request.query, request.model_name),
+                generate_response(form.query, image_data, form.model_name),
                 media_type="text/event-stream"
             )
         else:
            response = ""
-            for chunk in generate_response(request.query, request.model_name):
+            for chunk in generate_response(form.query, image_data, form.model_name):
                 response += chunk
             return {"response": response}
     except Exception as e:
-        raise HTTPException(status_code=500, detail=str(e))
+        logger.error(f"Error in /chat endpoint: {str(e)}")
+        raise HTTPException(status_code=500, detail=str(e))
+
+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run(
+        "main:app",
+        port=8000,
+        reload=True,  # Enable auto-reload
+        reload_dirs=["./"]  # Watch the current directory for changes
+    )
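
For reference, a minimal client sketch against the updated /chat endpoint, assuming the app runs locally on port 8000 and the requests library is installed; the URL, file path, and prompts below are illustrative and not part of the commit. After this change the endpoint expects multipart form data (the ChatForm fields plus an optional image file) instead of a JSON QueryRequest body.

import requests  # hypothetical client-side dependency, not used by app.py itself

URL = "http://localhost:8000/chat"  # assumed local address

# Text-only request: form fields map to ChatForm.as_form (query, stream, model_name)
r = requests.post(URL, data={"query": "What is FastAPI?", "stream": "false"})
print(r.json())

# Request with an optional image; the server opens it with PIL, converts it to
# RGB JPEG, base64-encodes it, and passes it to generate_response as image_data
with open("example.jpg", "rb") as f:  # placeholder file path
    r = requests.post(
        URL,
        data={
            "query": "Describe this image",
            "stream": "false",
            "model_name": "meta-llama/Meta-Llama-3-8B-Instruct",  # optional; falls back to DEFAULT_MODEL
        },
        files={"image": ("example.jpg", f, "image/jpeg")},
    )
print(r.json())

The ChatForm.as_form + Depends pattern used in the commit is the usual FastAPI workaround for mixing model-style validation with UploadFile parameters, since a JSON body and a file upload cannot share a single request model.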