import boto3
import pandas as pd
import json
import time
from tqdm import tqdm
import concurrent.futures
import os
from botocore.exceptions import ClientError
import numpy as np
import base64
from urllib.parse import urlparse
import io
from PIL import Image
# 定义CSV文件路径
csv_file = 'combined_dataset_fixed.csv'
# 读取包含图片路径的CSV文件
df = pd.read_csv(csv_file)
# 初始化Bedrock客户端
bedrock_runtime = boto3.client(
service_name='bedrock-runtime',
region_name='us-east-1' # 确保Claude模型在此区域可用
)
# 初始化S3客户端
s3_client = boto3.client('s3', region_name='us-west-1')
# 确保Bedrock有权限访问S3桶
def check_s3_permissions():
try:
# 检查桶是否存在
s3_client.head_bucket(Bucket="video-moderation-dataset")
print("成功连接到S3桶 'video-moderation-dataset'")
return True
except Exception as e:
print(f"无法访问S3桶: {str(e)}")
# 如果只处理本地文件,可以继续
return True
# 模型ID
# model_id = 'us.anthropic.claude-3-7-sonnet-20250219-v1:0'
model_id = 'us.amazon.nova-pro-v1:0'
# model_id='us.amazon.nova-lite-v1:0'
# 判断路径是S3路径还是本地路径
def is_s3_path(path):
return path.startswith('s3://')
# 从S3获取图片字节
def get_image_from_s3(s3_uri):
try:
parsed_url = urlparse(s3_uri)
bucket_name = parsed_url.netloc
object_key = parsed_url.path.lstrip('/')
response = s3_client.get_object(Bucket=bucket_name, Key=object_key)
image_bytes = response['Body'].read()
return image_bytes
except Exception as e:
print(f"从S3获取图片失败: {str(e)}")
return None
# 从本地文件系统获取图片字节
def get_image_from_local(file_path):
try:
with open(file_path, 'rb') as f:
image_bytes = f.read()
return image_bytes
except Exception as e:
print(f"从本地文件系统获取图片失败: {str(e)}")
return None
# 获取图片字节,无论是从S3还是本地
def get_image_bytes(path):
if is_s3_path(path):
return get_image_from_s3(path)
else:
return get_image_from_local(path)
# 处理图片,确保格式正确
def process_image(image_bytes):
try:
# 使用PIL打开图片
image = Image.open(io.BytesIO(image_bytes))
# 转换为RGB模式(如果是RGBA或其他模式)
if image.mode != 'RGB':
image = image.convert('RGB')
# 调整大小,确保不超过API限制
max_size = 4096 # 根据API限制调整
if max(image.size) > max_size:
ratio = max_size / max(image.size)
new_size = (int(image.size[0] * ratio), int(image.size[1] * ratio))
image = image.resize(new_size, Image.LANCZOS)
# 转换为JPEG格式的字节
buffer = io.BytesIO()
image.save(buffer, format="JPEG", quality=85)
processed_bytes = buffer.getvalue()
return processed_bytes
except Exception as e:
print(f"处理图片时出错: {str(e)}")
return image_bytes # 如果处理失败,返回原始字节
# 分析单个图片
def analyze_image(image_path, label):
print(f"处理图片: {image_path}")
try:
# 获取图片字节
image_bytes = get_image_bytes(image_path)
if image_bytes is None:
return {
'image_path': image_path,
'label': label,
'nova_lite_result': 'error',
'nova_lite_response': '无法获取图片数据'
}
# 处理图片,确保格式正确
processed_image = process_image(image_bytes)
# 获取图片扩展名
image_ext = "jpeg" # 默认为jpeg,因为我们在process_image中转换为JPEG
# prompt = """你是一个专业的图片内容审核专家,负责识别图片中的暴力和色情内容。请仔细分析图片内容并做出精确判断。
# 请遵循以下分类标准:
# 【暴力内容定义】
# - Violent: 图片包含真实的肢体冲突、打斗、斗殴、攻击行为,特别是街头打架、武器威胁、流血场景、明显的身体伤害或任何形式的人身伤害。
# - Non-violent: 图片展示和平的人类活动,如体育运动(即使是竞技性的)、日常生活场景(吃饭、行走、交谈)、工作场景等,没有真实的暴力冲突。
# 【色情内容定义】
# - Porn: 图片包含露骨的性行为、生殖器官暴露、性交行为、明确的性意图活动、裸体(非艺术性质)、性挑逗姿势等明确色情内容。
# - Non-porn: 图片可能包含比基尼、泳装、部分暴露但不露骨的内容,虽可能具有性感元素,但不展示明确的性行为或完全裸露。
# 分析步骤:
# 1. 首先详细描述图片中看到的内容,包括人物、动作、环境和上下文
# 2. 分别判断暴力和色情两个维度
# 3. 给出最终分类结果,必须是以下四种之一:
# - "Violent"(有暴力,无色情)
# - "Non-violent"(无暴力,无色情)
# - "Porn"(有色情,无论是否有暴力)
# - "Non-porn"(无色情,可能包含性感元素但不露骨)
# 如有任何不确定性,请倾向于将内容标记为"Violent"或"Porn"。你的回答必须包含上述四个标签之一,后面跟着简短的解释理由。"""
prompt = """你是一个专业的图片内容审核专家,负责识别图片中的暴力和色情内容。请仔细分析图片内容并做出精确判断。
请遵循以下分类标准:
- Violent: 图片必须包含以下一项或多项:
* 真实人物之间的打斗、殴打、攻击行为(注意:与表演、电影、体育竞技区分)
* 明显的流血、伤口或身体伤害
* 使用武器威胁或伤害他人
* 明显的肢体冲突和人身伤害场景
* 街头斗殴、帮派暴力或群体攻击行为
* 体育场上发生斗殴事件
- Non-violent: 不包含上述暴力元素的图片,即使是:
* 竞技性体育运动(包括拳击、摔跤、格斗比赛等)
* 表演或电影中的打斗场景
* 游戏或动画中的非真实暴力
* 人们进行日常活动(吃饭、行走、交谈、工作等)
【色情内容精确定义】
- Porn: 图片必须包含以下一项或多项:
* 明确的性行为或性交场景
* 裸露的生殖器官
* 裸露的成人胸部(女性)
* 明显的性行为暗示姿势或动作
* 性器官特写或强调
* 明确的性行为表演
- Non-porn: 不包含上述色情元素的图片,即使是:
* 穿着比基尼、泳装的人
* 部分暴露但不完全裸露的内容
* 健身、瑜伽或体育活动中的紧身服装
* 具有性感元素但不露骨的内容
* 艺术性质的裸体(如雕塑、古典艺术)
分析步骤:
1. 首先详细描述图片中看到的内容,包括人物、动作、环境和上下文
2. 分别判断暴力和色情两个维度
3. 给出最终分类结果,必须是以下四种之一:
- "Violent"(有暴力,无色情)
- "Non-violent"(无暴力,无色情)
- "Porn"(有色情,无论是否有暴力)
- "Non-porn"(无色情,可能包含性感元素但不露骨)
如有任何不确定性,请倾向于将内容标记为"Violent"或"Porn"。你的回答必须包含上述四个标签之一,后面跟着简短的解释理由。"""
max_retries = 2
retry_count = 0
response = None # 初始化response变量
# 打印图片路径和大小信息进行调试
print(f"图片路径: {image_path}")
print(f"处理后图片大小: {len(processed_image)} 字节")
# 调用Bedrock Converse API,直接传入图片字节
while retry_count < max_retries:
try:
response = bedrock_runtime.converse(
modelId=model_id,
messages=[
{
'role': 'user',
'content': [
{
'text': prompt
},
{
'image': {
'format': image_ext,
'source': {
'bytes': processed_image
}
}
}
]
}
]
)
break
except Exception as e:
retry_count += 1
if retry_count >= max_retries:
print(f"调用模型失败,已达到最大重试次数 ({max_retries})。错误: {str(e)}")
raise e
else:
print(f"调用模型失败,正在进行第 {retry_count} 次重试。错误: {str(e)}")
sleep_time = 30
time.sleep(sleep_time)
# 检查response是否为None
if response is None:
return {
'image_path': image_path,
'label': label,
'nova_lite_result': 'error',
'nova_lite_response': '无法获取模型响应'
}
# 解析响应
try:
response_text = response['output']['message']['content'][0]['text']
except KeyError:
try:
response_text = response['messages'][0]['content'][0]['text']
except KeyError:
response_text = "无法从响应中提取文本"
# 根据标签类型提取判断结果
if label in ['violent', 'non-violent']:
# 提取判断结果(violent或non-violent)
if 'violent' in response_text.lower() and not 'non-violent' in response_text.lower():
result = 'violent'
elif 'non-violent' in response_text.lower():
result = 'non-violent'
else:
# 如果无法明确判断,则进一步分析响应
if any(word in response_text.lower() for word in ['fight', 'attack', 'violence', 'aggressive', 'harm', 'push', 'hit', 'slap', 'touch', 'contact', 'fall', 'strike', 'throw', 'conflict', 'swing', 'wave', 'stick', 'bat']):
result = 'violent'
else:
result = 'non-violent'
elif label in ['porn', 'non-porn']:
# 提取判断结果(porn或non-porn)
if 'porn' in response_text.lower() and not 'non-porn' in response_text.lower():
result = 'porn'
elif 'non-porn' in response_text.lower():
result = 'non-porn'
else:
# 如果无法明确判断,则进一步分析响应
if any(word in response_text.lower() for word in ['nude', 'naked', 'sexual', 'explicit', 'adult', 'inappropriate']):
result = 'porn'
else:
result = 'non-porn'
else:
# 默认情况,直接使用标签作为结果
result = label
# 添加一些延迟以避免API限制
time.sleep(0.5)
return {
'image_path': image_path,
'label': label,
'nova_lite_result': result,
'nova_lite_response': response_text
}
except Exception as e:
print(f"处理图片 {image_path} 时出错: {str(e)}")
return {
'image_path': image_path,
'label': label,
'nova_lite_result': 'error',
'nova_lite_response': str(e)
}
# 主函数
def main():
# 检查S3权限
if not check_s3_permissions():
print("请确保已正确配置S3桶权限")
return
# 创建结果列表
results = []
# 获取图片路径和标签
image_data = list(zip(df['image_path'], df['label']))
print(f"开始分析 {len(image_data)} 个图片 ...")
# 使用线程池并行处理图片(限制并发数以避免API限制)
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
futures = []
for image_path, label in image_data:
future = executor.submit(analyze_image, image_path, label)
futures.append(future)
# 使用tqdm显示进度
for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="分析图片"):
result = future.result()
if result:
results.append(result)
# 创建结果DataFrame
results_df = pd.DataFrame(results)
successful_results = results_df[results_df['nova_lite_result'] != 'error']
# 保存结果到新的CSV文件
results_df.to_csv('image_analysis_results_nova.csv', index=False)
print(f"总共处理 {len(results_df)} 个图片,成功 {len(successful_results)} 个,失败 {len(results_df) - len(successful_results)} 个")
# 计算准确率
accuracy = (successful_results['label'] == successful_results['nova_lite_result']).mean()
print(f"Bedrock模型分析完成。准确率: {accuracy:.2%}")
# 打印混淆矩阵
print("\n混淆矩阵:")
confusion = pd.crosstab(
successful_results['label'],
successful_results['nova_lite_result'],
rownames=['实际'],
colnames=['预测']
)
print(confusion)
if __name__ == "__main__":
main()
评论