import boto3import pandas as pdimport jsonimport timefrom tqdm import tqdmimport concurrent.futuresimport osfrom botocore.exceptions import ClientErrorimport numpy as npimport base64from urllib.parse import urlparseimport iofrom PIL import Image
# 定义CSV文件路径csv_file = 'combined_dataset_fixed.csv'
# 读取包含图片路径的CSV文件df = pd.read_csv(csv_file)
# 初始化Bedrock客户端bedrock_runtime = boto3.client( service_name='bedrock-runtime', region_name='us-east-1' # 确保Claude模型在此区域可用)
# 初始化S3客户端s3_client = boto3.client('s3', region_name='us-west-1')
# 确保Bedrock有权限访问S3桶def check_s3_permissions(): try: # 检查桶是否存在 s3_client.head_bucket(Bucket="video-moderation-dataset") print("成功连接到S3桶 'video-moderation-dataset'") return True except Exception as e: print(f"无法访问S3桶: {str(e)}") # 如果只处理本地文件,可以继续 return True
# 模型ID# model_id = 'us.anthropic.claude-3-7-sonnet-20250219-v1:0'model_id = 'us.amazon.nova-pro-v1:0'# model_id='us.amazon.nova-lite-v1:0'
# 判断路径是S3路径还是本地路径def is_s3_path(path): return path.startswith('s3://')
# 从S3获取图片字节def get_image_from_s3(s3_uri): try: parsed_url = urlparse(s3_uri) bucket_name = parsed_url.netloc object_key = parsed_url.path.lstrip('/') response = s3_client.get_object(Bucket=bucket_name, Key=object_key) image_bytes = response['Body'].read() return image_bytes except Exception as e: print(f"从S3获取图片失败: {str(e)}") return None
# 从本地文件系统获取图片字节def get_image_from_local(file_path): try: with open(file_path, 'rb') as f: image_bytes = f.read() return image_bytes except Exception as e: print(f"从本地文件系统获取图片失败: {str(e)}") return None
# 获取图片字节,无论是从S3还是本地def get_image_bytes(path): if is_s3_path(path): return get_image_from_s3(path) else: return get_image_from_local(path)
# 处理图片,确保格式正确def process_image(image_bytes): try: # 使用PIL打开图片 image = Image.open(io.BytesIO(image_bytes)) # 转换为RGB模式(如果是RGBA或其他模式) if image.mode != 'RGB': image = image.convert('RGB') # 调整大小,确保不超过API限制 max_size = 4096 # 根据API限制调整 if max(image.size) > max_size: ratio = max_size / max(image.size) new_size = (int(image.size[0] * ratio), int(image.size[1] * ratio)) image = image.resize(new_size, Image.LANCZOS) # 转换为JPEG格式的字节 buffer = io.BytesIO() image.save(buffer, format="JPEG", quality=85) processed_bytes = buffer.getvalue() return processed_bytes except Exception as e: print(f"处理图片时出错: {str(e)}") return image_bytes # 如果处理失败,返回原始字节
# 分析单个图片def analyze_image(image_path, label): print(f"处理图片: {image_path}") try: # 获取图片字节 image_bytes = get_image_bytes(image_path) if image_bytes is None: return { 'image_path': image_path, 'label': label, 'nova_lite_result': 'error', 'nova_lite_response': '无法获取图片数据' } # 处理图片,确保格式正确 processed_image = process_image(image_bytes) # 获取图片扩展名 image_ext = "jpeg" # 默认为jpeg,因为我们在process_image中转换为JPEG # prompt = """你是一个专业的图片内容审核专家,负责识别图片中的暴力和色情内容。请仔细分析图片内容并做出精确判断。
# 请遵循以下分类标准:
# 【暴力内容定义】 # - Violent: 图片包含真实的肢体冲突、打斗、斗殴、攻击行为,特别是街头打架、武器威胁、流血场景、明显的身体伤害或任何形式的人身伤害。 # - Non-violent: 图片展示和平的人类活动,如体育运动(即使是竞技性的)、日常生活场景(吃饭、行走、交谈)、工作场景等,没有真实的暴力冲突。
# 【色情内容定义】 # - Porn: 图片包含露骨的性行为、生殖器官暴露、性交行为、明确的性意图活动、裸体(非艺术性质)、性挑逗姿势等明确色情内容。 # - Non-porn: 图片可能包含比基尼、泳装、部分暴露但不露骨的内容,虽可能具有性感元素,但不展示明确的性行为或完全裸露。
# 分析步骤: # 1. 首先详细描述图片中看到的内容,包括人物、动作、环境和上下文 # 2. 分别判断暴力和色情两个维度 # 3. 给出最终分类结果,必须是以下四种之一: # - "Violent"(有暴力,无色情) # - "Non-violent"(无暴力,无色情) # - "Porn"(有色情,无论是否有暴力) # - "Non-porn"(无色情,可能包含性感元素但不露骨)
# 如有任何不确定性,请倾向于将内容标记为"Violent"或"Porn"。你的回答必须包含上述四个标签之一,后面跟着简短的解释理由。"""
prompt = """你是一个专业的图片内容审核专家,负责识别图片中的暴力和色情内容。请仔细分析图片内容并做出精确判断。
请遵循以下分类标准:
- Violent: 图片必须包含以下一项或多项: * 真实人物之间的打斗、殴打、攻击行为(注意:与表演、电影、体育竞技区分) * 明显的流血、伤口或身体伤害 * 使用武器威胁或伤害他人 * 明显的肢体冲突和人身伤害场景 * 街头斗殴、帮派暴力或群体攻击行为 * 体育场上发生斗殴事件
- Non-violent: 不包含上述暴力元素的图片,即使是: * 竞技性体育运动(包括拳击、摔跤、格斗比赛等) * 表演或电影中的打斗场景 * 游戏或动画中的非真实暴力 * 人们进行日常活动(吃饭、行走、交谈、工作等)
【色情内容精确定义】 - Porn: 图片必须包含以下一项或多项: * 明确的性行为或性交场景 * 裸露的生殖器官 * 裸露的成人胸部(女性) * 明显的性行为暗示姿势或动作 * 性器官特写或强调 * 明确的性行为表演
- Non-porn: 不包含上述色情元素的图片,即使是: * 穿着比基尼、泳装的人 * 部分暴露但不完全裸露的内容 * 健身、瑜伽或体育活动中的紧身服装 * 具有性感元素但不露骨的内容 * 艺术性质的裸体(如雕塑、古典艺术) 分析步骤: 1. 首先详细描述图片中看到的内容,包括人物、动作、环境和上下文 2. 分别判断暴力和色情两个维度 3. 给出最终分类结果,必须是以下四种之一: - "Violent"(有暴力,无色情) - "Non-violent"(无暴力,无色情) - "Porn"(有色情,无论是否有暴力) - "Non-porn"(无色情,可能包含性感元素但不露骨)
如有任何不确定性,请倾向于将内容标记为"Violent"或"Porn"。你的回答必须包含上述四个标签之一,后面跟着简短的解释理由。"""
max_retries = 2 retry_count = 0 response = None # 初始化response变量 # 打印图片路径和大小信息进行调试 print(f"图片路径: {image_path}") print(f"处理后图片大小: {len(processed_image)} 字节") # 调用Bedrock Converse API,直接传入图片字节 while retry_count < max_retries: try: response = bedrock_runtime.converse( modelId=model_id, messages=[ { 'role': 'user', 'content': [ { 'text': prompt }, { 'image': { 'format': image_ext, 'source': { 'bytes': processed_image } } } ] } ] ) break except Exception as e: retry_count += 1 if retry_count >= max_retries: print(f"调用模型失败,已达到最大重试次数 ({max_retries})。错误: {str(e)}") raise e else: print(f"调用模型失败,正在进行第 {retry_count} 次重试。错误: {str(e)}") sleep_time = 30 time.sleep(sleep_time) # 检查response是否为None if response is None: return { 'image_path': image_path, 'label': label, 'nova_lite_result': 'error', 'nova_lite_response': '无法获取模型响应' } # 解析响应 try: response_text = response['output']['message']['content'][0]['text'] except KeyError: try: response_text = response['messages'][0]['content'][0]['text'] except KeyError: response_text = "无法从响应中提取文本" # 根据标签类型提取判断结果 if label in ['violent', 'non-violent']: # 提取判断结果(violent或non-violent) if 'violent' in response_text.lower() and not 'non-violent' in response_text.lower(): result = 'violent' elif 'non-violent' in response_text.lower(): result = 'non-violent' else: # 如果无法明确判断,则进一步分析响应 if any(word in response_text.lower() for word in ['fight', 'attack', 'violence', 'aggressive', 'harm', 'push', 'hit', 'slap', 'touch', 'contact', 'fall', 'strike', 'throw', 'conflict', 'swing', 'wave', 'stick', 'bat']): result = 'violent' else: result = 'non-violent' elif label in ['porn', 'non-porn']: # 提取判断结果(porn或non-porn) if 'porn' in response_text.lower() and not 'non-porn' in response_text.lower(): result = 'porn' elif 'non-porn' in response_text.lower(): result = 'non-porn' else: # 如果无法明确判断,则进一步分析响应 if any(word in response_text.lower() for word in ['nude', 'naked', 'sexual', 'explicit', 'adult', 'inappropriate']): result = 'porn' else: result = 'non-porn' else: # 默认情况,直接使用标签作为结果 result = label # 添加一些延迟以避免API限制 time.sleep(0.5) return { 'image_path': image_path, 'label': label, 'nova_lite_result': result, 'nova_lite_response': response_text } except Exception as e: print(f"处理图片 {image_path} 时出错: {str(e)}") return { 'image_path': image_path, 'label': label, 'nova_lite_result': 'error', 'nova_lite_response': str(e) }
# 主函数def main(): # 检查S3权限 if not check_s3_permissions(): print("请确保已正确配置S3桶权限") return # 创建结果列表 results = [] # 获取图片路径和标签 image_data = list(zip(df['image_path'], df['label'])) print(f"开始分析 {len(image_data)} 个图片 ...") # 使用线程池并行处理图片(限制并发数以避免API限制) with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor: futures = [] for image_path, label in image_data: future = executor.submit(analyze_image, image_path, label) futures.append(future) # 使用tqdm显示进度 for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="分析图片"): result = future.result() if result: results.append(result) # 创建结果DataFrame results_df = pd.DataFrame(results) successful_results = results_df[results_df['nova_lite_result'] != 'error'] # 保存结果到新的CSV文件 results_df.to_csv('image_analysis_results_nova.csv', index=False) print(f"总共处理 {len(results_df)} 个图片,成功 {len(successful_results)} 个,失败 {len(results_df) - len(successful_results)} 个") # 计算准确率 accuracy = (successful_results['label'] == successful_results['nova_lite_result']).mean() print(f"Bedrock模型分析完成。准确率: {accuracy:.2%}") # 打印混淆矩阵 print("\n混淆矩阵:") confusion = pd.crosstab( successful_results['label'], successful_results['nova_lite_result'], rownames=['实际'], colnames=['预测'] ) print(confusion)
if __name__ == "__main__": main()
评论