写点什么

用 300 行代码手写 1 个 Spring 框架,麻雀虽小五脏俱全

作者:Tom弹架构
  • 2021 年 12 月 30 日
  • 本文字数:11184 字

    阅读完需:约 37 分钟

本文节选自《Spring 5 核心原理》

1 自定义配置

1.1 配置 application.properties 文件

为了解析方便,我们用 application.properties 来代替 application.xml 文件,具体配置内容如下:



scanPackage=com.tom.demo
复制代码

1.2 配置 web.xml 文件

大家都知道,所有依赖于 Web 容器的项目都是从读取 web.xml 文件开始的。我们先配置好 web.xml 中的内容:



<?xml version="1.0" encoding="UTF-8"?><web-app xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xmlns="http://java.sun.com/xml/ns/j2ee" xmlns:javaee="http://java.sun.com/xml/ns/javaee" xmlns:web="http://java.sun.com/xml/ns/javaee/web-app_2_5.xsd" xsi:schemaLocation="http://java.sun.com/xml/ns/j2ee http://java.sun.com/xml/ns/j2ee/web-app_2_4.xsd" version="2.4"> <display-name>Gupao Web Application</display-name> <servlet> <servlet-name>gpmvc</servlet-name> <servlet-class>com.tom.mvcframework.v1.servlet.GPDispatcherServlet</servlet-class> <init-param> <param-name>contextConfigLocation</param-name> <param-value>application.properties</param-value> </init-param> <load-on-startup>1</load-on-startup> </servlet> <servlet-mapping> <servlet-name>gpmvc</servlet-name> <url-pattern>/*</url-pattern> </servlet-mapping></web-app>
复制代码


其中的 GPDispatcherServlet 是模拟 Spring 实现的核心功能类。

1.3 自定义注解

@GPService 注解如下:



package com.tom.mvcframework.annotation;import java.lang.annotation.*;@Target({ElementType.TYPE})@Retention(RetentionPolicy.RUNTIME)@Documentedpublic @interface GPService { String value() default "";}
复制代码


@GPAutowired 注解如下:



package com.tom.mvcframework.annotation;import java.lang.annotation.*;@Target({ElementType.FIELD})@Retention(RetentionPolicy.RUNTIME)@Documentedpublic @interface GPAutowired { String value() default "";}
复制代码


@GPController 注解如下:



package com.tom.mvcframework.annotation;import java.lang.annotation.*;@Target({ElementType.TYPE})@Retention(RetentionPolicy.RUNTIME)@Documentedpublic @interface GPController { String value() default "";}
复制代码


@GPRequestMapping 注解如下:



package com.tom.mvcframework.annotation;import java.lang.annotation.*;@Target({ElementType.TYPE,ElementType.METHOD})@Retention(RetentionPolicy.RUNTIME)@Documentedpublic @interface GPRequestMapping { String value() default "";}
复制代码


@GPRequestParam 注解如下:



package com.tom.mvcframework.annotation;import java.lang.annotation.*;@Target({ElementType.PARAMETER})@Retention(RetentionPolicy.RUNTIME)@Documentedpublic @interface GPRequestParam { String value() default "";}
复制代码

1.4 配置注解

配置业务实现类 DemoService:



package com.tom.demo.service.impl;import com.tom.demo.service.IDemoService;import com.tom.mvcframework.annotation.GPService;/** * 核心业务逻辑 */@GPServicepublic class DemoService implements IDemoService{ public String get(String name) { return "My name is " + name; }}
复制代码


配置请求入口类 DemoAction:



package com.tom.demo.mvc.action;import java.io.IOException;import javax.servlet.http.HttpServletRequest;import javax.servlet.http.HttpServletResponse;import com.tom.demo.service.IDemoService;import com.tom.mvcframework.annotation.GPAutowired;import com.tom.mvcframework.annotation.GPController;import com.tom.mvcframework.annotation.GPRequestMapping;import com.tom.mvcframework.annotation.GPRequestParam;@GPController@GPRequestMapping("/demo")public class DemoAction { @GPAutowired private IDemoService demoService; @GPRequestMapping("/query") public void query(HttpServletRequest req, HttpServletResponse resp, @GPRequestParam("name") String name){ String result = demoService.get(name); try { resp.getWriter().write(result); } catch (IOException e) { e.printStackTrace(); } } @GPRequestMapping("/add") public void add(HttpServletRequest req, HttpServletResponse resp, @GPRequestParam("a") Integer a, @GPRequestParam("b") Integer b){ try { resp.getWriter().write(a + "+" + b + "=" + (a + b)); } catch (IOException e) { e.printStackTrace(); } } @GPRequestMapping("/remove") public void remove(HttpServletRequest req,HttpServletResponse resp, @GPRequestParam("id") Integer id){ }}
复制代码


至此,配置全部完成。

2 容器初始化 1.0 版本

所有的核心逻辑全部写在 init()方法中,代码如下:



package com.tom.mvcframework.v1.servlet;import com.tom.mvcframework.annotation.GPAutowired;import com.tom.mvcframework.annotation.GPController;import com.tom.mvcframework.annotation.GPRequestMapping;import com.tom.mvcframework.annotation.GPService;import javax.servlet.ServletConfig;import javax.servlet.ServletException;import javax.servlet.http.HttpServlet;import javax.servlet.http.HttpServletRequest;import javax.servlet.http.HttpServletResponse;import java.io.File;import java.io.IOException;import java.io.InputStream;import java.lang.reflect.Field;import java.lang.reflect.Method;import java.net.URL;import java.util.*;
public class GPDispatcherServlet extends HttpServlet { private Map<String,Object> mapping = new HashMap<String, Object>(); @Override protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {this.doPost(req,resp);} @Override protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {} @Override public void init(ServletConfig config) throws ServletException { InputStream is = null; try{ Properties configContext = new Properties(); is = this.getClass().getClassLoader().getResourceAsStream(config.getInitParameter ("contextConfigLocation")); configContext.load(is); String scanPackage = configContext.getProperty("scanPackage"); doScanner(scanPackage); for (String className : mapping.keySet()) { if(!className.contains(".")){continue;} Class<?> clazz = Class.forName(className); if(clazz.isAnnotationPresent(GPController.class)){ mapping.put(className,clazz.newInstance()); String baseUrl = ""; if (clazz.isAnnotationPresent(GPRequestMapping.class)) { GPRequestMapping requestMapping = clazz.getAnnotation (GPRequestMapping.class); baseUrl = requestMapping.value(); } Method[] methods = clazz.getMethods(); for (Method method : methods) { if(!method.isAnnotationPresent(GPRequestMapping.class)){ continue; } GPRequestMapping requestMapping = method.getAnnotation (GPRequestMapping.class); String url = (baseUrl + "/" + requestMapping.value()).replaceAll("/+", "/"); mapping.put(url, method); System.out.println("Mapped " + url + "," + method); } }else if(clazz.isAnnotationPresent(GPService.class)){ GPService service = clazz.getAnnotation(GPService.class); String beanName = service.value(); if("".equals(beanName)){beanName = clazz.getName();} Object instance = clazz.newInstance(); mapping.put(beanName,instance); for (Class<?> i : clazz.getInterfaces()) { mapping.put(i.getName(),instance); } }else {continue;} } for (Object object : mapping.values()) { if(object == null){continue;} Class clazz = object.getClass(); if(clazz.isAnnotationPresent(GPController.class)){ Field [] fields = clazz.getDeclaredFields(); for (Field field : fields) { if(!field.isAnnotationPresent(GPAutowired.class)){continue; } GPAutowired autowired = field.getAnnotation(GPAutowired.class); String beanName = autowired.value(); if("".equals(beanName)){beanName = field.getType().getName();} field.setAccessible(true); try { field.set(mapping.get(clazz.getName()),mapping.get(beanName)); } catch (IllegalAccessException e) { e.printStackTrace(); } } } } } catch (Exception e) { }finally { if(is != null){ try {is.close();} catch (IOException e) { e.printStackTrace(); } } } System.out.print("GP MVC Framework is init"); } private void doScanner(String scanPackage) { URL url = this.getClass().getClassLoader().getResource("/" + scanPackage.replaceAll ("\\.","/")); File classDir = new File(url.getFile()); for (File file : classDir.listFiles()) { if(file.isDirectory()){ doScanner(scanPackage + "." + file.getName());}else { if(!file.getName().endsWith(".class")){continue;} String clazzName = (scanPackage + "." + file.getName().replace(".class","")); mapping.put(clazzName,null); } } }}
复制代码

3 请求执行

重点实现 doGet()和 doPost()方法,实际上就是在 doGet()和 doPost()方法中调用 doDispatch()方法,具体代码如下:



public class GPDispatcherServlet extends HttpServlet { private Map<String,Object> mapping = new HashMap<String, Object>(); @Override protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {this.doPost(req,resp);} @Override protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { try { doDispatch(req,resp); } catch (Exception e) { resp.getWriter().write("500 Exception " + Arrays.toString(e.getStackTrace())); } } private void doDispatch(HttpServletRequest req, HttpServletResponse resp) throws Exception { String url = req.getRequestURI(); String contextPath = req.getContextPath(); url = url.replace(contextPath, "").replaceAll("/+", "/"); if(!this.mapping.containsKey(url)){resp.getWriter().write("404 Not Found!!");return;} Method method = (Method) this.mapping.get(url); Map<String,String[]> params = req.getParameterMap(); method.invoke(this.mapping.get(method.getDeclaringClass().getName()),new Object[]{req,resp,params.get("name")[0]}); } @Override public void init(ServletConfig config) throws ServletException { ... }
}
复制代码

4 优化并实现 2.0 版本

在 1.0 版本上进行优化,采用常用的设计模式(工厂模式、单例模式、委派模式、策略模式),将 init()方法中的代码进行封装。按照之前的实现思路,先搭基础框架,再“填肉注血”,具体代码如下:



//初始化阶段@Overridepublic void init(ServletConfig config) throws ServletException {
//1. 加载配置文件 doLoadConfig(config.getInitParameter("contextConfigLocation"));
//2. 扫描相关的类 doScanner(contextConfig.getProperty("scanPackage")); //3. 初始化扫描到的类,并且将它们放入IoC容器中 doInstance(); //4. 完成依赖注入 doAutowired();
//5. 初始化HandlerMapping initHandlerMapping();
System.out.println("GP Spring framework is init.");
}
复制代码


声明全局成员变量,其中 IoC 容器就是注册时单例的具体案例:



//保存application.properties配置文件中的内容private Properties contextConfig = new Properties();
//保存扫描的所有的类名private List<String> classNames = new ArrayList<String>();
//传说中的IoC容器,我们来揭开它的神秘面纱//为了简化程序,暂时不考虑ConcurrentHashMap//主要还是关注设计思想和原理private Map<String,Object> ioc = new HashMap<String,Object>();
//保存url和Method的对应关系private Map<String,Method> handlerMapping = new HashMap<String,Method>();
复制代码


实现 doLoadConfig()方法:



//加载配置文件private void doLoadConfig(String contextConfigLocation) { //直接通过类路径找到Spring主配置文件所在的路径 //并且将其读取出来放到Properties对象中 //相当于将scanPackage=com.tom.demo保存到了内存中 InputStream fis = this.getClass().getClassLoader().getResourceAsStream(contextConfigLocation); try { contextConfig.load(fis); } catch (IOException e) { e.printStackTrace(); }finally { if(null != fis){ try { fis.close(); } catch (IOException e) { e.printStackTrace(); } } }}
复制代码


实现 doScanner()方法:



//扫描相关的类private void doScanner(String scanPackage) { //scanPackage = com.tom.demo ,存储的是包路径 //转换为文件路径,实际上就是把.替换为/ URL url = this.getClass().getClassLoader().getResource("/" + scanPackage.replaceAll ("\\.","/")); File classPath = new File(url.getFile()); for (File file : classPath.listFiles()) { if(file.isDirectory()){ doScanner(scanPackage + "." + file.getName()); }else{ if(!file.getName().endsWith(".class")){ continue;} String className = (scanPackage + "." + file.getName().replace(".class","")); classNames.add(className); } }}
复制代码


实现 doInstance()方法,doInstance()方法就是工厂模式的具体实现:



private void doInstance() { //初始化,为DI做准备 if(classNames.isEmpty()){return;}
try { for (String className : classNames) { Class<?> clazz = Class.forName(className);
//什么样的类才需要初始化呢? //加了注解的类才初始化,怎么判断? //为了简化代码逻辑,主要体会设计思想,只用@Controller和@Service举例, //@Componment等就不一一举例了 if(clazz.isAnnotationPresent(GPController.class)){ Object instance = clazz.newInstance(); //Spring默认类名首字母小写 String beanName = toLowerFirstCase(clazz.getSimpleName()); ioc.put(beanName,instance); }else if(clazz.isAnnotationPresent(GPService.class)){ //1. 自定义的beanName GPService service = clazz.getAnnotation(GPService.class); String beanName = service.value(); //2. 默认类名首字母小写 if("".equals(beanName.trim())){ beanName = toLowerFirstCase(clazz.getSimpleName()); }
Object instance = clazz.newInstance(); ioc.put(beanName,instance); //3. 根据类型自动赋值,这是投机取巧的方式 for (Class<?> i : clazz.getInterfaces()) { if(ioc.containsKey(i.getName())){ throw new Exception("The “" + i.getName() + "” is exists!!"); } //把接口的类型直接当成key ioc.put(i.getName(),instance); } }else { continue; }
} }catch (Exception e){ e.printStackTrace(); }
}
复制代码


为了处理方便,自己实现了 toLowerFirstCase()方法,来实现类名首字母小写,具体代码如下:



//将类名首字母改为小写private String toLowerFirstCase(String simpleName) { char [] chars = simpleName.toCharArray(); //之所以要做加法,是因为大、小写字母的ASCII码相差32 //而且大写字母的ASCII码要小于小写字母的ASCII码 //在Java中,对char做算术运算实际上就是对ASCII码做算术运算 chars[0] += 32; return String.valueOf(chars);}
复制代码


实现 doAutowired()方法:



//自动进行依赖注入private void doAutowired() { if(ioc.isEmpty()){return;}
for (Map.Entry<String, Object> entry : ioc.entrySet()) { //获取所有的字段,包括private、protected、default类型的 //正常来说,普通的OOP编程只能获得public类型的字段 Field[] fields = entry.getValue().getClass().getDeclaredFields(); for (Field field : fields) { if(!field.isAnnotationPresent(GPAutowired.class)){continue;} GPAutowired autowired = field.getAnnotation(GPAutowired.class);
//如果用户没有自定义beanName,默认就根据类型注入 //这个地方省去了对类名首字母小写的情况的判断,这个作为课后作业请“小伙伴们”自己去实现 String beanName = autowired.value().trim(); if("".equals(beanName)){ //获得接口的类型,作为key,稍后用这个key到IoC容器中取值 beanName = field.getType().getName(); }
//如果是public以外的类型,只要加了@Autowired注解都要强制赋值 //反射中叫作暴力访问 field.setAccessible(true);
try { //用反射机制动态给字段赋值 field.set(entry.getValue(),ioc.get(beanName)); } catch (IllegalAccessException e) { e.printStackTrace(); }

}
}

}
复制代码


实现 initHandlerMapping()方法,HandlerMapping 就是策略模式的应用案例:



//初始化url和Method的一对一关系private void initHandlerMapping() { if(ioc.isEmpty()){ return; }
for (Map.Entry<String, Object> entry : ioc.entrySet()) { Class<?> clazz = entry.getValue().getClass();
if(!clazz.isAnnotationPresent(GPController.class)){continue;}

//保存写在类上面的@GPRequestMapping("/demo") String baseUrl = ""; if(clazz.isAnnotationPresent(GPRequestMapping.class)){ GPRequestMapping requestMapping = clazz.getAnnotation(GPRequestMapping.class); baseUrl = requestMapping.value(); }
//默认获取所有的public类型的方法 for (Method method : clazz.getMethods()) { if(!method.isAnnotationPresent(GPRequestMapping.class)){continue;}
GPRequestMapping requestMapping = method.getAnnotation(GPRequestMapping.class); //优化 String url = ("/" + baseUrl + "/" + requestMapping.value()) .replaceAll("/+","/"); handlerMapping.put(url,method); System.out.println("Mapped :" + url + "," + method);
}

}

}
复制代码


到这里初始化的工作完成,接下来实现运行的逻辑,来看 doGet()和 doPost()方法的代码:



@Overrideprotected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { this.doPost(req,resp);}
@Overrideprotected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
//运行阶段 try { doDispatch(req,resp); } catch (Exception e) { e.printStackTrace(); resp.getWriter().write("500 Exection,Detail : " + Arrays.toString(e.getStackTrace())); }}
复制代码


doPost()方法中用了委派模式,委派模式的具体逻辑在 doDispatch()方法中实现:



private void doDispatch(HttpServletRequest req, HttpServletResponse resp)throws Exception { String url = req.getRequestURI(); String contextPath = req.getContextPath(); url = url.replaceAll(contextPath,"").replaceAll("/+","/"); if(!this.handlerMapping.containsKey(url)){ resp.getWriter().write("404 Not Found!!"); return; } Method method = this.handlerMapping.get(url); //第一个参数:方法所在的实例 //第二个参数:调用时所需要的实参
Map<String,String[]> params = req.getParameterMap(); //投机取巧的方式 String beanName = toLowerFirstCase(method.getDeclaringClass().getSimpleName()); method.invoke(ioc.get(beanName),new Object[]{req,resp,params.get("name")[0]}); //System.out.println(method);}
复制代码


在以上代码中,doDispatch()虽然完成了动态委派并进行了反射调用,但对 url 参数的处理还是静态的。要实现 url 参数的动态获取,其实有些复杂。我们可以优化 doDispatch()方法的实现,代码如下:



private void doDispatch(HttpServletRequest req, HttpServletResponse resp)throws Exception { String url = req.getRequestURI(); String contextPath = req.getContextPath(); url = url.replaceAll(contextPath,"").replaceAll("/+","/"); if(!this.handlerMapping.containsKey(url)){ resp.getWriter().write("404 Not Found!!"); return; }
Method method = this.handlerMapping.get(url); //第一个参数:方法所在的实例 //第二个参数:调用时所需要的实参 Map<String,String[]> params = req.getParameterMap(); //获取方法的形参列表 Class<?> [] parameterTypes = method.getParameterTypes(); //保存请求的url参数列表 Map<String,String[]> parameterMap = req.getParameterMap(); //保存赋值参数的位置 Object [] paramValues = new Object[parameterTypes.length]; //根据参数位置动态赋值 for (int i = 0; i < parameterTypes.length; i ++){ Class parameterType = parameterTypes[i]; if(parameterType == HttpServletRequest.class){ paramValues[i] = req; continue; }else if(parameterType == HttpServletResponse.class){ paramValues[i] = resp; continue; }else if(parameterType == String.class){
//提取方法中加了注解的参数 Annotation[] [] pa = method.getParameterAnnotations(); for (int j = 0; j < pa.length ; j ++) { for(Annotation a : pa[i]){ if(a instanceof GPRequestParam){ String paramName = ((GPRequestParam) a).value(); if(!"".equals(paramName.trim())){ String value = Arrays.toString(parameterMap.get(paramName)) .replaceAll("\\[|\\]","") .replaceAll("\\s",","); paramValues[i] = value; } } } }
} } //投机取巧的方式 //通过反射获得Method所在的Class,获得Class之后还要获得Class的名称 //再调用toLowerFirstCase获得beanName String beanName = toLowerFirstCase(method.getDeclaringClass().getSimpleName()); method.invoke(ioc.get(beanName),new Object[]{req,resp,params.get("name")[0]});}
复制代码


关注微信公众号『 Tom 弹架构 』回复“Spring”可获取完整源码。


本文为“Tom 弹架构”原创,转载请注明出处。技术在于分享,我分享我快乐!如果您有任何建议也可留言评论或私信,您的支持是我坚持创作的动力。关注微信公众号『 Tom 弹架构 』可获取更多技术干货!


原创不易,坚持很酷,都看到这里了,小伙伴记得点赞、收藏、在看,一键三连加关注!如果你觉得内容太干,可以分享转发给朋友滋润滋润!

发布于: 刚刚
用户头像

Tom弹架构

关注

不只做一个技术者,更要做一个思考者 2021.10.22 加入

畅销书作者,代表作品:《Spring 5核心原理》、《Netty 4核心原理》、《设计模式就该这样学》

评论

发布
暂无评论
用300行代码手写1个Spring框架,麻雀虽小五脏俱全