hysts HF staff commited on
Commit
69fd992
·
1 Parent(s): e5ba201

Refactor process_new_user_message to simplify file handling and improve readability

Browse files
Files changed (1) hide show
  1. app.py +38 -37
app.py CHANGED
@@ -42,43 +42,44 @@ def downsample_video(video_path: str) -> list[tuple[Image.Image, float]]:
42
 
43
 
44
  def process_new_user_message(message: dict) -> list[dict]:
45
- if message["files"]:
46
- if "<image>" in message["text"]:
47
- content = []
48
- logger.debug(f"{message['files']=}")
49
- parts = re.split(r"(<image>)", message["text"])
50
- image_index = 0
51
- logger.debug(f"{parts=}")
52
- for part in parts:
53
- logger.debug(f"{part=}")
54
- if part == "<image>":
55
- content.append({"type": "image", "url": message["files"][image_index]})
56
- logger.debug(f"file: {message['files'][image_index]}")
57
- image_index += 1
58
- elif part.strip():
59
- content.append({"type": "text", "text": part.strip()})
60
- elif isinstance(part, str) and part != "<image>":
61
- content.append({"type": "text", "text": part})
62
- logger.debug(f"{content=}")
63
- return content
64
- if message["files"][0].endswith(".mp4"):
65
- content = []
66
- video = message["files"].pop(0)
67
- frames = downsample_video(video)
68
- for frame in frames:
69
- pil_image, timestamp = frame
70
- with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
71
- pil_image.save(temp_file.name)
72
- content.append({"type": "text", "text": f"Frame {timestamp}:"})
73
- content.append({"type": "image", "url": temp_file.name})
74
- logger.debug(f"{content=}")
75
- return content
76
- # non interleaved images
77
- return [
78
- {"type": "text", "text": message["text"]},
79
- *[{"type": "image", "url": path} for path in message["files"]],
80
- ]
81
- return [{"type": "text", "text": message["text"]}]
 
82
 
83
 
84
  def process_history(history: list[dict]) -> list[dict]:
 
42
 
43
 
44
  def process_new_user_message(message: dict) -> list[dict]:
45
+ if not message["files"]:
46
+ return [{"type": "text", "text": message["text"]}]
47
+
48
+ if "<image>" in message["text"]:
49
+ content = []
50
+ logger.debug(f"{message['files']=}")
51
+ parts = re.split(r"(<image>)", message["text"])
52
+ image_index = 0
53
+ logger.debug(f"{parts=}")
54
+ for part in parts:
55
+ logger.debug(f"{part=}")
56
+ if part == "<image>":
57
+ content.append({"type": "image", "url": message["files"][image_index]})
58
+ logger.debug(f"file: {message['files'][image_index]}")
59
+ image_index += 1
60
+ elif part.strip():
61
+ content.append({"type": "text", "text": part.strip()})
62
+ elif isinstance(part, str) and part != "<image>":
63
+ content.append({"type": "text", "text": part})
64
+ logger.debug(f"{content=}")
65
+ return content
66
+ if message["files"][0].endswith(".mp4"):
67
+ content = []
68
+ video = message["files"].pop(0)
69
+ frames = downsample_video(video)
70
+ for frame in frames:
71
+ pil_image, timestamp = frame
72
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
73
+ pil_image.save(temp_file.name)
74
+ content.append({"type": "text", "text": f"Frame {timestamp}:"})
75
+ content.append({"type": "image", "url": temp_file.name})
76
+ logger.debug(f"{content=}")
77
+ return content
78
+ # non interleaved images
79
+ return [
80
+ {"type": "text", "text": message["text"]},
81
+ *[{"type": "image", "url": path} for path in message["files"]],
82
+ ]
83
 
84
 
85
  def process_history(history: list[dict]) -> list[dict]: