From 5c4ce5392c9e8e19dbeafa9df6a4e6906ee9dffb Mon Sep 17 00:00:00 2001 From: Sami Khan <98742866+samiamjidkhan@users.noreply.github.com> Date: Tue, 21 Jan 2025 04:33:54 -0500 Subject: [PATCH] image and text mode fix --- exo/api/chatgpt_api.py | 67 ++++++++++++++++++++++++++++++------------ 1 file changed, 48 insertions(+), 19 deletions(-) diff --git a/exo/api/chatgpt_api.py b/exo/api/chatgpt_api.py index d3e17803..acfb3b33 100644 --- a/exo/api/chatgpt_api.py +++ b/exo/api/chatgpt_api.py @@ -139,11 +139,23 @@ def remap_messages(messages: List[Message]) -> List[Message]: def build_prompt(tokenizer, _messages: List[Message], tools: Optional[List[Dict]] = None): messages = remap_messages(_messages) chat_template_args = {"conversation": [m.to_dict() for m in messages], "tokenize": False, "add_generation_prompt": True} - if tools: chat_template_args["tools"] = tools + if tools: + chat_template_args["tools"] = tools - prompt = tokenizer.apply_chat_template(**chat_template_args) - print(f"!!! Prompt: {prompt}") - return prompt + try: + prompt = tokenizer.apply_chat_template(**chat_template_args) + if DEBUG >= 3: print(f"!!! Prompt: {prompt}") + return prompt + except UnicodeEncodeError: + # Handle Unicode encoding by ensuring everything is UTF-8 + chat_template_args["conversation"] = [ + {k: v.encode('utf-8').decode('utf-8') if isinstance(v, str) else v + for k, v in m.to_dict().items()} + for m in messages + ] + prompt = tokenizer.apply_chat_template(**chat_template_args) + if DEBUG >= 3: print(f"!!! Prompt (UTF-8 encoded): {prompt}") + return prompt def parse_message(data: dict): @@ -213,11 +225,16 @@ class ChatGPTAPI: cors.add(self.app.router.add_post("/download", self.handle_post_download), {"*": cors_options}) cors.add(self.app.router.add_get("/topology", self.handle_get_topology), {"*": cors_options}) + # Add static routes if "__compiled__" not in globals(): self.static_dir = Path(__file__).parent.parent/"tinychat" self.app.router.add_get("/", self.handle_root) self.app.router.add_static("/", self.static_dir, name="static") - self.app.router.add_static('/images/', get_exo_images_dir(), name='static_images') + + # Always add images route, regardless of compilation status + self.images_dir = get_exo_images_dir() + self.images_dir.mkdir(parents=True, exist_ok=True) + self.app.router.add_static('/images/', self.images_dir, name='static_images') self.app.middlewares.append(self.timeout_middleware) self.app.middlewares.append(self.log_request) @@ -509,20 +526,32 @@ class ChatGPTAPI: await response.write(json.dumps({'progress': get_progress_bar((result[0]), (result[1]))}).encode('utf-8') + b'\n') elif isinstance(result, np.ndarray): - im = Image.fromarray(np.array(result)) - images_folder = get_exo_images_dir() - # Save the image to a file - image_filename = f"{_request_id}.png" - image_path = images_folder/image_filename - im.save(image_path) - image_url = request.app.router['static_images'].url_for(filename=image_filename) - base_url = f"{request.scheme}://{request.host}" - # Construct the full URL correctly - full_image_url = base_url + str(image_url) - - await response.write(json.dumps({'images': [{'url': str(full_image_url), 'content_type': 'image/png'}]}).encode('utf-8') + b'\n') - if is_finished: - await response.write_eof() + try: + im = Image.fromarray(np.array(result)) + # Save the image to a file + image_filename = f"{_request_id}.png" + image_path = self.images_dir/image_filename + im.save(image_path) + + # Get URL for the saved image + try: + image_url = request.app.router['static_images'].url_for(filename=image_filename) + base_url = f"{request.scheme}://{request.host}" + full_image_url = base_url + str(image_url) + + await response.write(json.dumps({'images': [{'url': str(full_image_url), 'content_type': 'image/png'}]}).encode('utf-8') + b'\n') + except KeyError as e: + if DEBUG >= 2: print(f"Error getting image URL: {e}") + # Fallback to direct file path if URL generation fails + await response.write(json.dumps({'images': [{'url': str(image_path), 'content_type': 'image/png'}]}).encode('utf-8') + b'\n') + + if is_finished: + await response.write_eof() + + except Exception as e: + if DEBUG >= 2: print(f"Error processing image: {e}") + if DEBUG >= 2: traceback.print_exc() + await response.write(json.dumps({'error': str(e)}).encode('utf-8') + b'\n') stream_task = None